jiuhai committed · Commit 4cd1d55 · verified · 1 Parent(s): a3c20e1

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.

Files changed (50):
  1. .gitattributes +8 -34
  2. .gitignore +41 -0
  3. LICENSE +381 -0
  4. README.md +120 -0
  5. packages/ltx-core/README.md +409 -0
  6. packages/ltx-core/src/ltx_core/conditioning/types/__init__.py +13 -0
  7. packages/ltx-core/src/ltx_core/conditioning/types/attention_strength_wrapper.py +71 -0
  8. packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py +70 -0
  9. packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py +44 -0
  10. packages/ltx-core/src/ltx_core/conditioning/types/reference_video_cond.py +91 -0
  11. packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-312.pyc +0 -0
  12. packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-312.pyc +0 -0
  13. packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py +10 -0
  14. packages/ltx-core/src/ltx_core/model/common/__init__.py +9 -0
  15. packages/ltx-core/src/ltx_core/model/common/__pycache__/__init__.cpython-312.pyc +0 -0
  16. packages/ltx-core/src/ltx_core/model/common/__pycache__/normalization.cpython-312.pyc +0 -0
  17. packages/ltx-core/src/ltx_core/model/common/normalization.py +59 -0
  18. packages/ltx-core/src/ltx_core/model/transformer/__init__.py +18 -0
  19. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/adaln.cpython-312.pyc +0 -0
  20. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer_args.cpython-312.pyc +0 -0
  21. packages/ltx-core/src/ltx_core/model/transformer/adaln.py +45 -0
  22. packages/ltx-core/src/ltx_core/model/transformer/attention.py +249 -0
  23. packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py +10 -0
  24. packages/ltx-core/src/ltx_core/model/transformer/modality.py +40 -0
  25. packages/ltx-core/src/ltx_core/model/transformer/model.py +486 -0
  26. packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py +152 -0
  27. packages/ltx-core/src/ltx_core/model/transformer/rope.py +204 -0
  28. packages/ltx-core/src/ltx_core/model/transformer/text_projection.py +38 -0
  29. packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py +143 -0
  30. packages/ltx-core/src/ltx_core/model/transformer/transformer.py +398 -0
  31. packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py +297 -0
  32. packages/ltx-core/src/ltx_core/model/video_vae/__init__.py +24 -0
  33. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/__init__.cpython-312.pyc +0 -0
  34. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/convolution.cpython-312.pyc +0 -0
  35. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/enums.cpython-312.pyc +0 -0
  36. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/model_configurator.cpython-312.pyc +0 -0
  37. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/ops.cpython-312.pyc +0 -0
  38. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/resnet.cpython-312.pyc +0 -0
  39. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/sampling.cpython-312.pyc +0 -0
  40. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/tiling.cpython-312.pyc +0 -0
  41. packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/video_vae.cpython-312.pyc +0 -0
  42. packages/ltx-core/src/ltx_core/model/video_vae/convolution.py +317 -0
  43. packages/ltx-core/src/ltx_core/model/video_vae/model_configurator.py +79 -0
  44. packages/ltx-core/src/ltx_core/model/video_vae/resnet.py +277 -0
  45. packages/ltx-core/src/ltx_core/model/video_vae/tiling.py +291 -0
  46. packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py +1219 -0
  47. packages/ltx-core/src/ltx_core/quantization/__pycache__/__init__.cpython-312.pyc +0 -0
  48. packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_cast.cpython-312.pyc +0 -0
  49. packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_scaled_mm.cpython-312.pyc +0 -0
  50. packages/ltx-core/src/ltx_core/quantization/__pycache__/policy.cpython-312.pyc +0 -0
.gitattributes CHANGED
@@ -1,35 +1,9 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.sft filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,41 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
checkpoints/
*.egg-info

# Virtual environments
.venv
.python-version

# IDE settings
.idea/
.vscode/

# Other files
.DS_Store
tmp
.wandb

# Model checkpoints
*.ckpt
*.pt
*.safetensors
*.sft

# Media files
*.gif
*.heic
*.heif
*.jpg
*.jpeg
*.json
*.m4a
*.mov
*.mp4
*.png
*.wav
*.webp
LICENSE ADDED
@@ -0,0 +1,381 @@
LTX-2 Community License Agreement
License date: January 5, 2026

By using or distributing any portion or element of LTX-2, you agree to be bound by this Agreement.

1. Definitions.

"Agreement" means the terms and conditions for the license, use, reproduction, and distribution of LTX-2 and the Complementary Materials, as specified in this document.

"Control" means the direct or indirect ownership of more than fifty percent (50%) of the voting securities or other ownership interests, or the power to direct the management and policies of such Entity through voting rights, contract, or otherwise.

"Data" means a collection of information and/or content extracted from the dataset used with LTX-2, including to train, pretrain, or otherwise evaluate LTX-2. The Data is not licensed under this Agreement.

"Derivatives of LTX-2" means all modifications to LTX-2, works based on LTX-2, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of LTX-2, to the other model, in order to cause the other model to perform similarly to LTX-2, including – but not limited to – distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by LTX-2 for training the other model. For clarity, Derivatives of LTX-2 include: (i) any fine-tuned or adapted weights, parameters, or checkpoints derived from LTX-2; (ii) derivative model architectures that incorporate or are based upon LTX-2's architecture; and (iii) any modified or extended versions of the Complementary Materials. All intellectual property rights in Derivatives of LTX-2 shall be subject to the terms of this Agreement, and you may not claim exclusive ownership rights in any Derivatives of LTX-2 that would restrict the rights granted herein.

"Entity" means any individual, corporation, partnership, limited liability company, or other legal entity. For purposes of this Agreement, an Entity shall be deemed to include, on an aggregative basis, all subsidiaries, affiliates, and other companies under common Control with such Entity. When determining whether an Entity meets any threshold under this Agreement (including revenue thresholds), all subsidiaries, affiliates, and companies under common Control shall be considered collectively.

"Harm" includes but is not limited to physical, mental, psychological, financial and reputational damage, pain, or loss.

"Licensor" or "Lightricks" means the owner that is granting the license under this Agreement. For the purposes of this Agreement, the Licensor is Lightricks Ltd.

"LTX-2" means the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code, accompanying source code, scripts, documentation, tutorials, examples, and all other elements of the foregoing distributed and made publicly available by Lightricks (including, for example, at https://github.com/Lightricks/LTX-2) for the LTX-2 model released on January 5, 2026. This license is applicable to all LTX-2 versions released since January 5, 2026, and all future releases of LTX-2 under this license.

"Output" means the results of operating LTX-2 as embodied in informational content resulting therefrom.

"you" (or "your") means an individual or legal Entity licensing LTX-2 in accordance with this Agreement and/or making use of LTX-2 for whichever purpose and in any field of use, including usage of LTX-2 in an end-use application - e.g. chatbot, translator, image generator.

2. Grant of License. Subject to the terms and conditions of this Agreement, you are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Licensor's intellectual property or other rights owned by Licensor embodied in LTX-2 to use, reproduce, prepare, distribute, publicly display, publicly perform, sublicense, copy, create derivative works of, and make modifications to LTX-2, for any purpose, subject to the restrictions set forth in Attachment A; provided however, that Entities with annual revenues of at least $10,000,000 (the "Commercial Entities") are required to obtain a paid commercial use license in order to use LTX-2 and Derivatives of LTX-2, subject to the terms and provisions of a different license (the "Commercial Use Agreement"), as will be provided by the Licensor. Commercial Entities interested in such a commercial license are required to [contact Licensor](https://ltx.io/model/licensing). Any commercial use of LTX-2 or Derivatives of LTX-2 by the Commercial Entities not in accordance with this Agreement and/or the Commercial Use Agreement is strictly prohibited and shall be deemed a material breach of this Agreement. Such material breach will be subject, in addition to any license fees owed to Licensor for the period such Commercial Entity used LTX-2 (as will be determined by Licensor), to liquidated damages, which will be paid to Licensor immediately upon demand, in an amount equal to double the amount that would otherwise have been paid by you for the relevant period of time. Such amount reflects a reasonable estimation of the losses and administrative costs incurred due to such breach. You agree and understand that this remedy does not limit the Licensor's right to pursue other remedies available at law or equity.

3. Distribution and Redistribution. You may host for third parties remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of LTX-2 or Derivatives of LTX-2 thereof in any medium, with or without modifications, provided that you meet the following conditions:

(a) Use-based restrictions as referenced in paragraph 4 and all provisions of Attachment A MUST be included as an enforceable provision by you in any type of legal agreement (e.g. a license) governing the use and/or distribution of LTX-2 or Derivatives of LTX-2, and you shall give notice to subsequent users you distribute to, that LTX-2 or Derivatives of LTX-2 are subject to paragraph 4 and Attachment A in their entirety, including all use restrictions and acceptable use policies;

(b) You must provide any third party recipients of LTX-2 or Derivatives of LTX-2 a copy of this Agreement, including all attachments and use policies. Any Derivative of LTX-2 (as defined in Section 1, including but not limited to fine-tuned weights, modified training code, models trained on Outputs, or any other derivative) must be distributed exclusively under the terms of this Agreement with a complete copy of this license included;

(c) You must cause any modified files to carry prominent notices stating that you changed the files;

(d) You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of LTX-2, Derivatives of LTX-2.

You may add your own copyright statement to your modifications and may provide additional or different license terms and conditions - respecting paragraph 3(a) - for use, reproduction, or distribution of your modifications, or for any such Derivatives of LTX-2 as a whole, provided your use, reproduction, and distribution of LTX-2 otherwise complies with the conditions stated in this Agreement, and you provide a complete copy of this Agreement with any such use, reproduction and distribution of LTX-2 and any Derivatives thereof.

4. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore, you cannot use LTX-2 and the Derivatives of LTX-2 in violation of the specified restricted uses. You may use LTX-2 subject to this Agreement, including only for lawful purposes and in accordance with the Agreement. "Use" may include creating any content with, fine-tuning, updating, running, training, evaluating and/or re-parametrizing LTX-2. You shall require all of your users who use LTX-2 or a Derivative of LTX-2 to comply with the terms of this paragraph 4.

5. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output you generate using LTX-2. You are accountable for the input you insert into LTX-2, the Output you generate and its subsequent uses. No use of the Output can contravene any provision as stated in the Agreement.

6. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of LTX-2 in violation of this Agreement, update LTX-2 through electronic means, or modify the Output of LTX-2 based on updates. You shall undertake reasonable efforts to use the latest version of LTX-2. Any use of a non-current version of LTX-2 is done solely at your risk.

7. Export Controls and Sanctions Compliance. You acknowledge that LTX-2 and Derivatives of LTX-2 may be subject to export control laws and regulations, including but not limited to the U.S. Export Administration Regulations and sanctions programs administered by the Office of Foreign Assets Control (OFAC). You represent and warrant that you and any users of LTX-2 are not (i) located in, organized under the laws of, or ordinarily resident in any country or territory subject to comprehensive sanctions; (ii) identified on any U.S. government restricted party list, including the Specially Designated Nationals and Blocked Persons List; or (iii) otherwise prohibited from receiving LTX-2 under applicable law. You shall not export, re-export, or transfer LTX-2, directly or indirectly, in violation of any applicable export control or sanctions laws or regulations. You agree to comply with all applicable trade control laws and shall indemnify and hold Licensor harmless from any claims arising from your failure to comply with such laws.

8. Trademarks and related. Nothing in this Agreement permits you to make use of Licensor's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensor.

9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides LTX-2 on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing LTX-2 and Derivatives of LTX-2 and assume any risks associated with your exercise of permissions under this Agreement.

10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Licensor be liable to you for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use LTX-2 (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if Licensor has been advised of the possibility of such damages.

11. Accepting Warranty or Additional Liability. While redistributing LTX-2 and Derivatives of LTX-2, you may, provided you do not violate the terms of this Agreement, choose to offer and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations. However, in accepting such obligations, you may act only on your own behalf and on your sole responsibility, not on behalf of Licensor, and only if you agree to indemnify, defend, and hold Licensor harmless for any liability incurred by, or claims asserted against Licensor, by reason of your accepting any such warranty or additional liability.

12. Governing Law. This Agreement and all relations, disputes, claims and other matters arising hereunder (including non-contractual disputes or claims) will be governed exclusively by, and construed exclusively in accordance with, the laws of the State of New York. To the extent permitted by law, choice of laws rules and the United Nations Convention on Contracts for the International Sale of Goods will not apply. For the purposes of adjudicating any action or proceeding to enforce the terms of this Agreement, you hereby irrevocably consent to the exclusive jurisdiction of, and venue in, the federal and state courts located in the County of New York within the State of New York. The prevailing party in any claim or dispute between the parties under this Agreement will be entitled to reimbursement of its reasonable attorneys' fees and costs. You hereby waive the right to a trial by jury, to participate in a class or representative action (including in arbitration), or to combine individual proceedings in court or in arbitration without the consent of all parties.

13. Term and Termination. This Agreement is effective upon your acceptance and continues until terminated. Licensor may terminate this Agreement immediately upon written notice to you if you breach any provision of this Agreement, including but not limited to violations of the use restrictions in Attachment A or unauthorized commercial use. Upon termination: (a) all rights granted to you under this Agreement will immediately cease; (b) you must immediately cease all use of LTX-2 and Derivatives of LTX-2; (c) you must delete or destroy all copies of LTX-2 and Derivatives of LTX-2 in your possession or control; and (d) you must notify any third parties to whom you distributed LTX-2 or Derivatives of LTX-2 of the termination. Sections 8-13, and Section 15 shall survive termination of this Agreement. Termination does not relieve you of any obligations incurred prior to termination, including payment obligations under Section 2. In addition, if You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Licensor or any person or entity alleging that LTX-2 or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licenses granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed.

14. Disputes and Arbitration. All disputes arising in connection with this Agreement shall be finally settled by arbitration under the Rules of Arbitration of the International Chamber of Commerce ("ICC Rules"), by one (1) arbitrator appointed in accordance with the ICC Rules. The seat of arbitration shall be New York, NY, USA, and the proceedings shall be conducted in English. The arbitrator shall be empowered to grant any relief that a court could grant. Judgment on the arbitration award may be entered by any court having jurisdiction thereof. Each party waives its right to a trial by jury and to participate in any class or representative action.

15. If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

END OF TERMS AND CONDITIONS

ATTACHMENT A: Use Restrictions

When using the Outputs, LTX-2 and any Derivatives thereof, you will comply with the Acceptable Use Policy. In addition, you agree not to use the Outputs, LTX-2 or its Derivatives in any of the following ways:

1. In any way that violates any applicable national, federal, state, local or international law or regulation;

2. For the purpose of exploiting, Harming or attempting to exploit or Harm minors in any way;

3. To generate or disseminate false information and/or content with the purpose of Harming others;

4. To generate or disseminate personal identifiable information that can be used to Harm an individual;

5. To generate or disseminate information and/or content (e.g. images, code, posts, articles), and place the information and/or content in any context (e.g. bot generating tweets) without expressly and intelligibly disclaiming that the information and/or content is machine generated;

6. To defame, disparage or otherwise harass others;

7. To impersonate or attempt to impersonate (e.g. deepfakes) others without their consent;

8. For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation;

9. For any use intended to or which has the effect of discriminating against or Harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;

10. To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological Harm;

11. For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;

12. To provide medical advice and medical results interpretation;

13. To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use);

14. To generate and/or disseminate malware (including – but not limited to – ransomware) or any other content to be used for the purpose of harming electronic systems;

15. To engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, or other essential goods and services;

16. To engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals;

17. For military, warfare, nuclear industries or applications, weapons development, or any use in connection with activities that may cause death, personal injury, or severe physical or environmental damage;

18. For commercial use only: To train, improve, or fine-tune any other machine learning model, artificial intelligence system, or competing model, except for Derivatives of LTX-2 as expressly permitted under this Agreement;

19. To circumvent, disable, or interfere with any technical limitations, safety features, content filters, or use restrictions implemented in LTX-2 by Licensor;

20. To use LTX-2 or Derivatives of LTX-2 in any product, service, or application that directly competes with Licensor's commercial products or services, or is designed to replace or substitute Licensor's offerings in the market, without obtaining a separate commercial license from Licensor.
README.md ADDED
@@ -0,0 +1,120 @@
# LTX-2

[![Website](https://img.shields.io/badge/Website-LTX-181717?logo=google-chrome)](https://ltx.io)
[![Model](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface)](https://huggingface.co/Lightricks/LTX-2.3)
[![Demo](https://img.shields.io/badge/Demo-Try%20Now-brightgreen?logo=vercel)](https://app.ltx.studio/ltx-2-playground/i2v)
[![Paper](https://img.shields.io/badge/Paper-PDF-EC1C24?logo=adobeacrobatreader&logoColor=white)](https://arxiv.org/abs/2601.03233)
[![Discord](https://img.shields.io/badge/Join-Discord-5865F2?logo=discord)](https://discord.gg/ltxplatform)

**LTX-2** is the first DiT-based audio-video foundation model that brings together the core capabilities of modern video generation in one model: synchronized audio and video, high fidelity, multiple performance modes, production-ready outputs, API access, and open access.

<div align="center">
<video src="https://github.com/user-attachments/assets/4414adc0-086c-43de-b367-9362eeb20228" width="70%"> </video>
</div>

## 🚀 Quick Start

```bash
# Clone the repository
git clone https://github.com/Lightricks/LTX-2.git
cd LTX-2

# Set up the environment
uv sync --frozen
source .venv/bin/activate
```

### Required Models

Download the following models from the [LTX-2.3 HuggingFace repository](https://huggingface.co/Lightricks/LTX-2.3):

**LTX-2.3 Model Checkpoint** (choose and download one of the following)

* [`ltx-2.3-22b-dev.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-22b-dev.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-22b-dev.safetensors)
* [`ltx-2.3-22b-distilled.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-22b-distilled.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-22b-distilled.safetensors)

**Spatial Upscaler** - Required for the current two-stage pipeline implementations in this repository

* [`ltx-2.3-spatial-upscaler-x2-1.0.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-spatial-upscaler-x2-1.0.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-spatial-upscaler-x2-1.0.safetensors)
* [`ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors)

**Temporal Upscaler** - Supported by the model and will be required for future pipeline implementations

* [`ltx-2.3-temporal-upscaler-x2-1.0.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-temporal-upscaler-x2-1.0.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-temporal-upscaler-x2-1.0.safetensors)

**Distilled LoRA** - Required for the current two-stage pipeline implementations in this repository (except DistilledPipeline and ICLoraPipeline)

* [`ltx-2.3-22b-distilled-lora-384.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-22b-distilled-lora-384.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-22b-distilled-lora-384.safetensors)

**Gemma Text Encoder** (download all assets from the repository)

* [`Gemma 3`](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/tree/main)

**LoRAs**

* [`LTX-2.3-22b-IC-LoRA-Union-Control`](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control) - [Download](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control/resolve/main/ltx-2.3-22b-ic-lora-union-control-ref0.5.safetensors)
* [`LTX-2.3-22b-IC-LoRA-Inpainting`](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Inpainting) - [Download](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Inpainting/resolve/main/ltx-2.3-22b-ic-lora-inpainting.safetensors)
* [`LTX-2.3-22b-IC-LoRA-Motion-Track-Control`](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control) - [Download](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control/resolve/main/ltx-2.3-22b-ic-lora-motion-track-control-ref0.5.safetensors)
* [`LTX-2-19b-IC-LoRA-Detailer`](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Detailer) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Detailer/resolve/main/ltx-2-19b-ic-lora-detailer.safetensors)
* [`LTX-2-19b-IC-LoRA-Pose-Control`](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Pose-Control) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Pose-Control/resolve/main/ltx-2-19b-ic-lora-pose-control.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-In`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In/resolve/main/ltx-2-19b-lora-camera-control-dolly-in.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-Left`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left/resolve/main/ltx-2-19b-lora-camera-control-dolly-left.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-Out`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out/resolve/main/ltx-2-19b-lora-camera-control-dolly-out.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-Right`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right/resolve/main/ltx-2-19b-lora-camera-control-dolly-right.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Jib-Down`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down/resolve/main/ltx-2-19b-lora-camera-control-jib-down.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Jib-Up`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up/resolve/main/ltx-2-19b-lora-camera-control-jib-up.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Static`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static/resolve/main/ltx-2-19b-lora-camera-control-static.safetensors)

### Available Pipelines

* **[TI2VidTwoStagesPipeline](packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py)** - Production-quality text/image-to-video with 2x upsampling (recommended)
* **[TI2VidTwoStagesHQPipeline](packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages_hq.py)** - Same two-stage flow as above but uses the res_2s second-order sampler (fewer steps, better quality)
* **[TI2VidOneStagePipeline](packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py)** - Single-stage generation for quick prototyping
* **[DistilledPipeline](packages/ltx-pipelines/src/ltx_pipelines/distilled.py)** - Fastest inference with 8 predefined sigmas
* **[ICLoraPipeline](packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py)** - Video-to-video and image-to-video transformations (uses the distilled model)
* **[KeyframeInterpolationPipeline](packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py)** - Interpolate between keyframe images
* **[A2VidPipelineTwoStage](packages/ltx-pipelines/src/ltx_pipelines/a2vid_two_stage.py)** - Audio-to-video generation conditioned on an input audio file
* **[RetakePipeline](packages/ltx-pipelines/src/ltx_pipelines/retake.py)** - Regenerate a specific time region of an existing video

### ⚡ Optimization Tips

* **Use DistilledPipeline** - Fastest inference with only 8 predefined sigmas (8 steps stage 1, 4 steps stage 2)
* **Enable FP8 quantization** - Reduces the memory footprint: `--quantization fp8-cast` (CLI) or `quantization=QuantizationPolicy.fp8_cast()` (Python). For Hopper GPUs with TensorRT-LLM, use `--quantization fp8-scaled-mm` for FP8 scaled matrix multiplication; see the snippet below.
* **Install attention optimizations** - Use xFormers (`uv sync --extra xformers`) or [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) for Hopper GPUs
* **Use gradient estimation** - Reduce inference steps from 40 to 20-30 while maintaining quality (see the [pipeline documentation](packages/ltx-pipelines/README.md#denoising-loop-optimization))
* **Skip memory cleanup** - If you have sufficient VRAM, disable automatic memory cleanup between stages for faster processing
* **Choose single-stage pipeline** - Use `TI2VidOneStagePipeline` for faster generation when high resolution isn't required
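For the FP8 options above, the `QuantizationPolicy` constructors are documented in the [LTX-Core README](packages/ltx-core/README.md); a minimal sketch follows (how the policy object is handed to a specific pipeline may differ):

```python
from ltx_core.quantization import QuantizationPolicy

# Simple cast backend: weights stored in FP8, upcast on the fly at inference time
policy = QuantizationPolicy.fp8_cast()

# Hopper + TensorRT-LLM backend: FP8 scaled matrix multiplication
# policy = QuantizationPolicy.fp8_scaled_mm()
```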
## ✍️ Prompting for LTX-2

When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure:

- Start with the main action in a single sentence
- Add specific details about movements and gestures
- Describe character/object appearances precisely
- Include background and environment details
- Specify camera angles and movements
- Describe lighting and colors
- Note any changes or sudden events

For additional guidance on writing a prompt, please refer to <https://ltx.video/blog/how-to-prompt-for-ltx-2>

### Automatic Prompt Enhancement

LTX-2 pipelines support automatic prompt enhancement via an `enhance_prompt` parameter.
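A hypothetical call shape (the import path follows the pipeline list above; the constructor arguments and call signature are assumptions, not the documented API):

```python
# Hypothetical sketch -- only the `enhance_prompt` flag is documented here; see
# packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py for the real API.
from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline

pipeline = TI2VidTwoStagesPipeline(...)  # checkpoint/device arguments omitted
video = pipeline(
    prompt="A lighthouse on a rocky shore at dusk, waves crashing below",
    enhance_prompt=True,  # rewrites the prompt using the bundled system prompts
)
```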
## 🔌 ComfyUI Integration

To use our model with ComfyUI, please follow the instructions at <https://github.com/Lightricks/ComfyUI-LTXVideo/>.

## 📦 Packages

This repository is organized as a monorepo with three main packages:

* **[ltx-core](packages/ltx-core/)** - Core model implementation, inference stack, and utilities
* **[ltx-pipelines](packages/ltx-pipelines/)** - High-level pipeline implementations for text-to-video, image-to-video, and other generation modes
* **[ltx-trainer](packages/ltx-trainer/)** - Training and fine-tuning tools for LoRA, full fine-tuning, and IC-LoRA

Each package has its own README and documentation. See the [Documentation](#-documentation) section below.

## 📚 Documentation

Each package includes comprehensive documentation:

* **[LTX-Core README](packages/ltx-core/README.md)** - Core model implementation, inference stack, and utilities
* **[LTX-Pipelines README](packages/ltx-pipelines/README.md)** - High-level pipeline implementations and usage guides
* **[LTX-Trainer README](packages/ltx-trainer/README.md)** - Training and fine-tuning documentation with detailed guides
packages/ltx-core/README.md ADDED
@@ -0,0 +1,409 @@
# LTX-Core

The foundational library for the LTX-2 audio-video generation model. This package contains the raw model definitions, component implementations, and loading logic used by `ltx-pipelines` and `ltx-trainer`.

## 📦 What's Inside?

- **`components/`**: Modular diffusion components (Schedulers, Guiders, Noisers, Patchifiers) following standard protocols
- **`conditioning/`**: Tools for preparing latent states and applying conditioning (image, video, keyframes)
- **`guidance/`**: Perturbation system for fine-grained control over attention mechanisms
- **`loader/`**: Utilities for loading weights from `.safetensors`, fusing LoRAs, and managing memory
- **`model/`**: PyTorch implementations of the LTX-2 Transformer, Video VAE, Audio VAE, Vocoder, and Upscaler
- **`text_encoders/gemma`**: Gemma text encoder implementation with tokenizers, feature extractors, and separate encoders for audio-video and video-only generation
- **`quantization/`**: FP8 quantization backends (FP8-TensorRT-LLM scaled MM, FP8 cast) for a reduced memory footprint

## 🚀 Quick Start

`ltx-core` provides the building blocks (models, components, and utilities) needed to construct inference flows. For ready-made inference pipelines, use [`ltx-pipelines`](../ltx-pipelines/); for training, use [`ltx-trainer`](../ltx-trainer/).

## 🔧 Installation

```bash
# From the repository root
uv sync --frozen

# Or install as a package
pip install -e packages/ltx-core
```

## Building Blocks Overview

`ltx-core` provides modular components that can be combined to build custom inference flows:

### Core Models

- **Transformer** ([`model/transformer/`](src/ltx_core/model/transformer/)): The asymmetric dual-stream LTX-2 transformer (14B-parameter video stream, 5B-parameter audio stream) with bidirectional cross-modal attention for joint audio-video processing. Expects inputs in [`Modality`](src/ltx_core/model/transformer/modality.py) format
- **Video VAE** ([`model/video_vae/`](src/ltx_core/model/video_vae/)): Encodes/decodes video pixels to/from latent space with temporal and spatial compression
- **Audio VAE** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Encodes/decodes audio spectrograms to/from latent space
- **Vocoder** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Neural vocoder that converts mel spectrograms to audio waveforms
- **Text Encoder** ([`text_encoders/`](src/ltx_core/text_encoders/)): Gemma 3-based multilingual encoder with multi-layer feature extraction and thinking tokens that produces separate embeddings for video and audio conditioning
- **Spatial Upscaler** ([`model/upsampler/`](src/ltx_core/model/upsampler/)): Upsamples latent representations for higher-resolution generation

### Diffusion Components

- **Schedulers** ([`components/schedulers.py`](src/ltx_core/components/schedulers.py)): Noise schedules (LTX2Scheduler, LinearQuadratic, Beta) that control the denoising process
- **Guiders** ([`components/guiders.py`](src/ltx_core/components/guiders.py)): Guidance strategies (CFG, STG, APG) for controlling generation quality and adherence to prompts
- **Noisers** ([`components/noisers.py`](src/ltx_core/components/noisers.py)): Add noise to latents according to the diffusion schedule
- **Patchifiers** ([`components/patchifiers.py`](src/ltx_core/components/patchifiers.py)): Convert between spatial latents `[B, C, F, H, W]` and sequence format `[B, seq_len, dim]` for transformer processing; see the sketch below
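A minimal sketch of the patchify shape contract (the patch sizes and the function name here are illustrative assumptions, not the `ltx_core` API):

```python
import torch

def patchify(latents: torch.Tensor, pf: int = 1, ph: int = 2, pw: int = 2) -> torch.Tensor:
    """[B, C, F, H, W] -> [B, seq_len, dim] with dim = C*pf*ph*pw (illustrative)."""
    b, c, f, h, w = latents.shape
    x = latents.reshape(b, c, f // pf, pf, h // ph, ph, w // pw, pw)
    x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)  # group the patch dims with the channels
    return x.reshape(b, (f // pf) * (h // ph) * (w // pw), c * pf * ph * pw)

tokens = patchify(torch.randn(1, 128, 5, 16, 16))
print(tokens.shape)  # torch.Size([1, 320, 512])
```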
### Conditioning & Control

- **Conditioning** ([`conditioning/`](src/ltx_core/conditioning/)): Tools for preparing and applying various conditioning types (image, video, keyframes)
- **Guidance** ([`guidance/`](src/ltx_core/guidance/)): Perturbation system for fine-grained control over attention mechanisms (e.g., skipping specific attention layers)

### Utilities

- **Loader** ([`loader/`](src/ltx_core/loader/)): Model loading from `.safetensors`, LoRA fusion, weight remapping, and memory management
- **Quantization** ([`quantization/`](src/ltx_core/quantization/)): FP8 quantization backends for a reduced memory footprint and faster inference

### Loader

The `loader/` module provides `SingleGPUModelBuilder`, a frozen dataclass that loads a PyTorch model from `.safetensors` checkpoints and optionally fuses one or more LoRA adapters.

#### Basic usage

```python
import torch

from ltx_core.loader import SingleGPUModelBuilder

builder = SingleGPUModelBuilder(
    model_class_configurator=MyModelConfigurator,  # your model's configurator class
    model_path="/path/to/model.safetensors",
)
model = builder.build(device=torch.device("cuda"))
```

#### Loading LoRA adapters

Use the `.lora()` method to attach one or more LoRA adapters before calling `.build()`:

```python
builder = (
    SingleGPUModelBuilder(
        model_class_configurator=MyModelConfigurator,
        model_path="/path/to/model.safetensors",
    )
    .lora("/path/to/lora_a.safetensors", strength=0.8)
    .lora("/path/to/lora_b.safetensors", strength=0.5)
)
model = builder.build(device=torch.device("cuda"))
```

#### Memory-efficient LoRA loading (`lora_load_device`)

By default, LoRA weights are loaded onto the **CPU** (`lora_load_device=torch.device("cpu")`). Each LoRA adapter is kept in CPU memory and transferred to the GPU sequentially during weight fusion, which keeps peak GPU memory low even when fusing large adapters.

If all adapters fit comfortably in GPU memory, you can skip the CPU staging by setting `lora_load_device` to the target CUDA device:

```python
import torch
from ltx_core.loader import SingleGPUModelBuilder

# Load LoRA weights directly onto the GPU (faster, but uses more GPU memory)
builder = SingleGPUModelBuilder(
    model_class_configurator=MyModelConfigurator,
    model_path="/path/to/model.safetensors",
    lora_load_device=torch.device("cuda"),
).lora("/path/to/lora.safetensors", strength=1.0)

model = builder.build(device=torch.device("cuda"))
```

### Quantization

The `quantization/` module provides FP8 quantization support for the LTX-2 transformer, significantly reducing memory usage while maintaining quality. Two backends are available:

#### FP8 Scaled MM (TensorRT-LLM)

Uses NVIDIA TensorRT-LLM's `cublas_scaled_mm` for efficient FP8 matrix multiplication. Weights are stored in FP8 format with per-tensor scaling, and inputs are quantized dynamically (or statically with calibration data).

**Requirements**: `uv sync --frozen --extra fp8-trtllm`

**Usage with QuantizationPolicy:**

```python
from ltx_core.quantization import QuantizationPolicy

# Dynamic input quantization (no calibration needed)
policy = QuantizationPolicy.fp8_scaled_mm()

# Static input quantization with calibration file
policy = QuantizationPolicy.fp8_scaled_mm(calibration_amax_path="/path/to/amax.json")
```

The policy provides `sd_ops` and `module_ops` that can be passed to the model builder:

```python
from ltx_core.loader import SingleGPUModelBuilder

builder = SingleGPUModelBuilder(
    model=model,
    device=device,
    sd_ops=policy.sd_ops,
    module_ops=policy.module_ops,
)
builder.load(checkpoint_path)
```

**Calibration File Format** (for static input quantization):

```json
{
  "amax_values": {
    "transformer_blocks.0.attn.to_q.input_quantizer": 12.5,
    "transformer_blocks.0.attn.to_k.input_quantizer": 8.3,
    ...
  }
}
```

#### FP8 Cast

A simpler approach that casts weights to FP8 for storage and upcasts them during inference:

```python
policy = QuantizationPolicy.fp8_cast()
```
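The underlying idea, sketched in plain PyTorch (illustrative only; the actual cast/upcast logic lives in the `quantization/fp8_cast.py` module):

```python
import torch

# Store a weight in FP8 (about half the size of bf16), upcast just before use.
w = torch.randn(1024, 1024, dtype=torch.bfloat16)
w_fp8 = w.to(torch.float8_e4m3fn)       # FP8 storage copy (requires torch >= 2.1)
x = torch.randn(4, 1024, dtype=torch.bfloat16)
y = x @ w_fp8.to(torch.bfloat16).T      # upcast at inference time, then matmul
```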
For complete, production-ready pipeline implementations that combine these building blocks, see the [`ltx-pipelines`](../ltx-pipelines/) package.

---

# Architecture Overview

This section provides a deep dive into the internal architecture of the LTX-2 audio-video generation model.

## Table of Contents

1. [High-Level Architecture](#high-level-architecture)
2. [The Transformer](#the-transformer)
3. [Video VAE](#video-vae)
4. [Audio VAE](#audio-vae)
5. [Text Encoding (Gemma)](#text-encoding-gemma)
6. [Spatial Upscaler](#spatial-upsampler)
7. [Data Flow](#data-flow)

---

## High-Level Architecture

LTX-2 is an **asymmetric dual-stream diffusion transformer** that jointly models the text-conditioned distribution of video and audio signals, capturing true joint dependencies (unlike sequential T2V→V2A pipelines).

### Key Design Principles

- **Decoupled Latent Representations**: Separate modality-specific VAEs enable 3D RoPE (video) vs. 1D RoPE (audio), independent compression optimization, and native V2A/A2V editing workflows
- **Asymmetric Dual-Stream**: 14B-parameter video stream (spatiotemporal dynamics) + 5B-parameter audio stream (1D temporal), sharing 48 transformer blocks but differing in width
- **Bidirectional Cross-Modal Attention**: 1D temporal RoPE enables sub-frame alignment, mapping visual cues to auditory events (lip-sync, foley, environmental acoustics)
- **Cross-Modality AdaLN**: Scaling/shift parameters conditioned on the other modality's hidden states for synchronization across differing diffusion timesteps/temporal resolutions

```text
┌─────────────────────────────────────────────────────────────┐
│                      INPUT PREPARATION                       │
│                                                              │
│  Video Pixels   → Video VAE Encoder → Video Latents          │
│  Audio Waveform → Audio VAE Encoder → Audio Latents          │
│  Text Prompt    → Gemma 3 Encoder   → Text Embeddings        │
└─────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────┐
│    LTX-2 ASYMMETRIC DUAL-STREAM TRANSFORMER (48 Blocks)      │
│                                                              │
│  ┌──────────────────────┐      ┌──────────────────────┐      │
│  │  Video Stream (14B)  │      │  Audio Stream (5B)   │      │
│  │                      │      │                      │      │
│  │  3D RoPE (x,y,t)     │      │  1D RoPE (temporal)  │      │
│  │                      │      │                      │      │
│  │  Self-Attn           │      │  Self-Attn           │      │
│  │  Text Cross-Attn     │      │  Text Cross-Attn     │      │
│  │                      │◄────►│                      │      │
│  │  A↔V Cross-Attn      │      │  A↔V Cross-Attn      │      │
│  │  (1D temporal RoPE)  │      │  (1D temporal RoPE)  │      │
│  │  Cross-modality      │      │  Cross-modality      │      │
│  │  AdaLN               │      │  AdaLN               │      │
│  │  Feed-Forward        │      │  Feed-Forward        │      │
│  └──────────────────────┘      └──────────────────────┘      │
└─────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────┐
│                       OUTPUT DECODING                        │
│                                                              │
│  Video Latents   → Video VAE Decoder → Video Pixels          │
│  Audio Latents   → Audio VAE Decoder → Mel Spectrogram       │
│  Mel Spectrogram → Vocoder → Audio Waveform (24 kHz)         │
└─────────────────────────────────────────────────────────────┘
```

---

## The Transformer

The core of LTX-2 is an **asymmetric dual-stream diffusion transformer** with 48 layers that processes video and audio tokens simultaneously. The architecture allocates 14B parameters to the video stream and 5B parameters to the audio stream, reflecting the different information densities of the two modalities.

### Model Structure

**Source**: [`src/ltx_core/model/transformer/model.py`](src/ltx_core/model/transformer/model.py)

The `LTXModel` class implements the transformer. It supports both video-only and audio-video generation modes. For actual usage, see the [`ltx-pipelines`](../ltx-pipelines/) package, which handles model loading and initialization.

### Transformer Block Architecture

**Source**: [`src/ltx_core/model/transformer/transformer.py`](src/ltx_core/model/transformer/transformer.py)

Each dual-stream block performs four operations sequentially:

1. **Self-Attention**: Within-modality attention for each stream
2. **Text Cross-Attention**: Textual prompt conditioning for both streams
3. **Audio-Visual Cross-Attention**: Bidirectional inter-modal exchange
4. **Feed-Forward Network (FFN)**: Feature refinement

```text
┌─────────────────────────────────────────────────────────────┐
│                      TRANSFORMER BLOCK                       │
│                                                              │
│  VIDEO (14B): Input → RMSNorm → AdaLN → Self-Attn →          │
│               RMSNorm → Text Cross-Attn →                    │
│               RMSNorm → AdaLN → A↔V Cross-Attn (1D RoPE) →   │
│               RMSNorm → AdaLN → FFN → Output                 │
│                                                              │
│  AUDIO (5B):  Input → RMSNorm → AdaLN → Self-Attn →          │
│               RMSNorm → Text Cross-Attn →                    │
│               RMSNorm → AdaLN → A↔V Cross-Attn (1D RoPE) →   │
│               RMSNorm → AdaLN → FFN → Output                 │
│                                                              │
│  RoPE: Video=3D (x,y,t), Audio=1D (t), Cross-Attn=1D (t)     │
│  AdaLN: Timestep-conditioned, cross-modality for A↔V CA      │
└─────────────────────────────────────────────────────────────┘
```
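A toy, self-contained analogue of this four-step flow (the dimensions, module choices, and the omission of the RMSNorm/AdaLN wrappers are simplifications for illustration; the real block lives in `transformer.py`):

```python
import torch
import torch.nn as nn

class ToyDualStreamBlock(nn.Module):
    """Illustrative analogue of an LTX-2 dual-stream block; norms/AdaLN omitted."""
    def __init__(self, dv: int = 64, da: int = 32, dt: int = 48, heads: int = 4):
        super().__init__()
        self.v_self = nn.MultiheadAttention(dv, heads, batch_first=True)
        self.a_self = nn.MultiheadAttention(da, heads, batch_first=True)
        self.v_text = nn.MultiheadAttention(dv, heads, kdim=dt, vdim=dt, batch_first=True)
        self.a_text = nn.MultiheadAttention(da, heads, kdim=dt, vdim=dt, batch_first=True)
        self.a2v = nn.MultiheadAttention(dv, heads, kdim=da, vdim=da, batch_first=True)
        self.v2a = nn.MultiheadAttention(da, heads, kdim=dv, vdim=dv, batch_first=True)
        self.v_ff = nn.Sequential(nn.Linear(dv, 4 * dv), nn.GELU(), nn.Linear(4 * dv, dv))
        self.a_ff = nn.Sequential(nn.Linear(da, 4 * da), nn.GELU(), nn.Linear(4 * da, da))

    def forward(self, v, a, text):
        v = v + self.v_self(v, v, v)[0]        # 1. self-attention (per stream)
        a = a + self.a_self(a, a, a)[0]
        v = v + self.v_text(v, text, text)[0]  # 2. text cross-attention
        a = a + self.a_text(a, text, text)[0]
        v2 = self.a2v(v, a, a)[0]              # 3. bidirectional A<->V exchange
        a2 = self.v2a(a, v, v)[0]
        v, a = v + v2, a + a2
        return v + self.v_ff(v), a + self.a_ff(a)  # 4. feed-forward

v, a = ToyDualStreamBlock()(
    torch.randn(1, 320, 64),   # video tokens
    torch.randn(1, 100, 32),   # audio tokens
    torch.randn(1, 77, 48),    # text embeddings
)
```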
### Audio-Visual Cross-Attention Details

Bidirectional cross-attention enables tight temporal alignment: the video and audio streams exchange information in both directions using 1D temporal RoPE (synchronization only, no spatial alignment). AdaLN gates condition on each modality's timestep for cross-modal synchronization.

### Perturbations

The transformer supports [**perturbations**](src/ltx_core/guidance/perturbations.py) that selectively skip attention operations. Perturbations let you disable specific attention mechanisms during inference, which is useful for guidance techniques like STG (Spatio-Temporal Guidance).

**Supported Perturbation Types**:

- `SKIP_VIDEO_SELF_ATTN`: Skip video self-attention
- `SKIP_AUDIO_SELF_ATTN`: Skip audio self-attention
- `SKIP_A2V_CROSS_ATTN`: Skip audio-to-video cross-attention
- `SKIP_V2A_CROSS_ATTN`: Skip video-to-audio cross-attention

Perturbations are used internally by guidance mechanisms such as STG. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
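A schematic sketch of the skip mechanism (the enum values mirror the documented perturbation types above; this is not the `perturbations.py` API):

```python
from enum import Enum, auto

class Perturbation(Enum):
    SKIP_VIDEO_SELF_ATTN = auto()
    SKIP_AUDIO_SELF_ATTN = auto()
    SKIP_A2V_CROSS_ATTN = auto()
    SKIP_V2A_CROSS_ATTN = auto()

def video_self_attn_step(tokens, attn_fn, active_perturbations):
    # STG-style guidance runs a second, "perturbed" forward pass in which the
    # targeted attention op is skipped: the residual stream passes through unchanged.
    if Perturbation.SKIP_VIDEO_SELF_ATTN in active_perturbations:
        return tokens
    return tokens + attn_fn(tokens)
```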
---

## Video VAE

The Video VAE ([`src/ltx_core/model/video_vae/`](src/ltx_core/model/video_vae/)) encodes video pixels into latent representations and decodes them back.

### Architecture

- **Encoder**: Compresses `[B, 3, F, H, W]` pixels → `[B, 128, F', H/32, W/32]` latents
  - Where `F' = 1 + (F-1)/8` (the frame count must satisfy `(F-1) % 8 == 0`)
  - Example: `[B, 3, 33, 512, 512]` → `[B, 128, 5, 16, 16]`
- **Decoder**: Expands `[B, 128, F, H, W]` latents → `[B, 3, F', H*32, W*32]` pixels
  - Where `F' = 1 + (F-1)*8`
  - Example: `[B, 128, 5, 16, 16]` → `[B, 3, 33, 512, 512]`
310
+
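+ The frame arithmetic is easy to verify in a few lines (a standalone sketch, not part of the package):
+
+ ```python
+ # Latent shape arithmetic for the Video VAE encoder
+ def video_latent_shape(frames: int, height: int, width: int) -> tuple[int, int, int]:
+     assert (frames - 1) % 8 == 0, "frame count must satisfy (F-1) % 8 == 0"
+     return 1 + (frames - 1) // 8, height // 32, width // 32
+
+ assert video_latent_shape(33, 512, 512) == (5, 16, 16)
+ ```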
311
+ The Video VAE is used internally by pipelines for encoding video pixels to latents and decoding latents back to pixels. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
312
+
313
+ ---
314
+
315
+ ## Audio VAE
316
+
317
+ The Audio VAE ([`src/ltx_core/model/audio_vae/`](src/ltx_core/model/audio_vae/)) processes audio spectrograms.
318
+
319
+ ### Audio VAE Architecture
320
+
321
+ A compact neural audio representation optimized for diffusion-based training. It natively supports stereo: two-channel mel spectrograms (16 kHz input) are concatenated along the channel axis before encoding.
322
+
323
+ - **Encoder**: `[B, mel_bins, T]` → `[B, 8, T/4, 16]` latents (4× temporal downsampling, 8 channels, 16 mel bins in latent space, ~1/25s per token, 128-dim feature vector)
324
+ - **Decoder**: `[B, 8, T, 16]` → `[B, mel_bins, T*4]` mel spectrogram
325
+ - **Vocoder**: HiFi-GAN-based, modified for stereo synthesis and upsampling (16 kHz mel → 24 kHz waveform, doubled generator capacity for stereo)
326
+
327
+ **Downsampling**:
328
+
329
+ - Temporal: 4× (time steps)
330
+ - Frequency: Variable (input mel_bins → fixed 16 in latent space)
331
+
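+ The latent shape arithmetic, as a standalone sketch (not part of the package):
+
+ ```python
+ # Audio VAE encoder shapes: [B, mel_bins, T] -> [B, 8, T/4, 16]
+ def audio_latent_shape(mel_frames: int) -> tuple[int, int, int]:
+     return 8, mel_frames // 4, 16  # channels, downsampled time steps, latent mel bins
+
+ channels, steps, bins = audio_latent_shape(400)
+ assert (channels, steps, bins) == (8, 100, 16)
+ assert channels * bins == 128  # the 128-dim feature vector per latent time step
+ ```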
332
+ The Audio VAE is used internally by pipelines for encoding mel spectrograms to latents and decoding latents back to mel spectrograms. The vocoder converts mel spectrograms to audio waveforms. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
333
+
334
+ ---
335
+
336
+ ## Text Encoding (Gemma)
337
+
338
+ LTX-2 uses **Gemma 3 12B** as its multilingual text encoder backbone, located in [`src/ltx_core/text_encoders/gemma/`](src/ltx_core/text_encoders/gemma/). Advanced text understanding is critical not only for global language support but also for the phonetic and semantic accuracy of generated speech.
339
+
340
+ ### Text Encoder Architecture
341
+
342
+ The text conditioning pipeline consists of three stages:
343
+
344
+ 1. **Gemma 3 Backbone**: Decoder-only LLM processes text tokens → embeddings across all layers `[B, T, D, L]`
345
+ 2. **Multi-Layer Feature Extractor**: Aggregates features from all decoder layers (not just the final layer), applies mean-centered scaling, flattens to `[B, T, D×L]`, and projects via learnable matrix W (jointly optimized with LTX-2, LLM weights frozen); see the sketch after this list
346
+ 3. **Text Connector**: Bidirectional transformer blocks with learnable registers (replacing padded positions, also referred to as "thinking tokens" in the paper) for contextual mixing. Separate connectors for video and audio streams (`Embeddings1DConnector`)
347
+
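+ A minimal sketch of stage 2, assuming a simple mean-centering step (the tensor layout and normalization axis are assumptions, not the shipped code):
+
+ ```python
+ import torch
+
+ def extract_text_features(hidden_states: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+     """hidden_states: [B, T, D, L] embeddings from all Gemma decoder layers."""
+     b, t, d, n_layers = hidden_states.shape
+     centered = hidden_states - hidden_states.mean(dim=2, keepdim=True)  # mean-centered scaling (assumed axis)
+     flat = centered.reshape(b, t, d * n_layers)  # flatten to [B, T, D*L]
+     return flat @ w  # learnable projection W, trained jointly with LTX-2 (LLM frozen)
+ ```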
348
+ **Encoders**:
349
+
350
+ - `AVGemmaTextEncoderModel`: Audio-video generation (two connectors → `AVGemmaEncoderOutput` with separate video/audio contexts)
351
+ - `VideoGemmaTextEncoderModel`: Video-only generation (single connector → `VideoGemmaEncoderOutput`)
352
+
353
+ ### System Prompts
354
+
355
+ System prompts are also used to enhance users' prompts.
356
+
357
+ - **Text-to-Video**: [`gemma_t2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt)
358
+ - **Image-to-Video**: [`gemma_i2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt)
359
+
360
+ **Important**: Video and audio receive **different** context embeddings, even from the same prompt. This allows better modality-specific conditioning and enables the model to synthesize speech that is synchronized with visual lip movement while being natural in cadence, accent, and emotional tone.
361
+
362
+ **Output Format**:
363
+
364
+ - Video context: `[B, seq_len, 4096]` - Video-specific text embeddings
365
+ - Audio context: `[B, seq_len, 2048]` - Audio-specific text embeddings
366
+
367
+ The text encoder is used internally by pipelines. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
368
+
369
+ ---
370
+
371
+ ## Upscaler
372
+
373
+ The Upscaler ([`src/ltx_core/model/upsampler/`](src/ltx_core/model/upsampler/)) upsamples latent representations for higher-resolution output.
374
+
375
+ The spatial upsampler is used internally by two-stage pipelines (e.g., [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py), [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py)) to upsample low-resolution latents before final VAE decoding. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
376
+
377
+ ---
378
+
379
+ ## Data Flow
380
+
381
+ ### Complete Generation Pipeline
382
+
383
+ Here's how all the components work together conceptually ([`src/ltx_core/components/`](src/ltx_core/components/)):
384
+
385
+ **Pipeline Steps**:
386
+
387
+ 1. **Text Encoding**: Text prompt → Gemma encoder → separate video/audio embeddings
388
+ 2. **Latent Initialization**: Initialize noise latents in spatial format `[B, C, F, H, W]`
389
+ 3. **Patchification**: Convert spatial latents to sequence format `[B, seq_len, dim]` for transformer
390
+ 4. **Sigma Schedule**: Generate noise schedule (adapts to token count)
391
+ 5. **Denoising Loop**: Iteratively denoise using transformer predictions (sketched after this list)
392
+ - Create Modality inputs with per-token timesteps and RoPE positions
393
+ - Forward pass through transformer (conditional and unconditional for CFG)
394
+ - Apply guidance (CFG, STG, etc.)
395
+ - Update latents using diffusion step (Euler, etc.)
396
+ 6. **Unpatchification**: Convert sequence back to spatial format
397
+ 7. **VAE Decoding**: Decode latents to pixel space (with optional upsampling for two-stage)
398
+
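+ A compressed sketch of steps 4-5, the denoising loop (the names are illustrative, not the exact `ltx-pipelines` API):
+
+ ```python
+ # Euler flow-matching update over the sigma schedule
+ for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
+     timesteps = sigma * denoise_mask                     # per-token timesteps
+     denoised = x0_model(video, audio, perturbations)     # conditional pass (plus an unconditional pass for CFG)
+     denoised = apply_guidance(denoised, guidance_scale)  # CFG, STG, ...
+     velocity = (latents - denoised) / sigma
+     latents = latents + (sigma_next - sigma) * velocity  # Euler step
+ ```
+
+ The following pipelines implement this flow end to end: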
399
+ - [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py) - Two-stage text-to-video (recommended)
400
+ - [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py) - Video-to-video with IC-LoRA control
401
+ - [`DistilledPipeline`](../ltx-pipelines/src/ltx_pipelines/distilled.py) - Fast inference with distilled model
402
+ - [`KeyframeInterpolationPipeline`](../ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py) - Keyframe-based interpolation
403
+
404
+ See the [ltx-pipelines README](../ltx-pipelines/README.md) for usage examples.
405
+
406
+ ## 🔗 Related Projects
407
+
408
+ - **[ltx-pipelines](../ltx-pipelines/)** - High-level pipeline implementations for text-to-video, image-to-video, and video-to-video
409
+ - **[ltx-trainer](../ltx-trainer/)** - Training and fine-tuning tools
packages/ltx-core/src/ltx_core/conditioning/types/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conditioning type implementations."""
2
+
3
+ from ltx_core.conditioning.types.attention_strength_wrapper import ConditioningItemAttentionStrengthWrapper
4
+ from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex
5
+ from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex
6
+ from ltx_core.conditioning.types.reference_video_cond import VideoConditionByReferenceLatent
7
+
8
+ __all__ = [
9
+ "ConditioningItemAttentionStrengthWrapper",
10
+ "VideoConditionByKeyframeIndex",
11
+ "VideoConditionByLatentIndex",
12
+ "VideoConditionByReferenceLatent",
13
+ ]
packages/ltx-core/src/ltx_core/conditioning/types/attention_strength_wrapper.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wrapper conditioning item that adds attention masking to any inner conditioning."""
2
+
3
+ from dataclasses import replace
4
+
5
+ import torch
6
+
7
+ from ltx_core.conditioning.item import ConditioningItem
8
+ from ltx_core.conditioning.mask_utils import update_attention_mask
9
+ from ltx_core.tools import LatentTools
10
+ from ltx_core.types import LatentState
11
+
12
+
13
+ class ConditioningItemAttentionStrengthWrapper(ConditioningItem):
14
+ """Wraps a conditioning item to add an attention mask for its tokens.
15
+ Separates the *attention-masking* concern from the underlying conditioning
16
+ logic (token layout, positional encoding, denoise strength). The inner
17
+ conditioning item appends tokens to the latent sequence as usual, and this
18
+ wrapper then builds or updates the self-attention mask so that the newly
19
+ added tokens interact with the noisy tokens according to *attention_mask*.
20
+ Args:
21
+ conditioning: Any conditioning item that appends tokens to the latent.
22
+ attention_mask: Per-token attention weight controlling how strongly the
23
+ new conditioning tokens attend to/from noisy tokens. Can be a
24
+ scalar (float) applied uniformly, or a tensor of shape ``(B, M)``
25
+ for spatial control, where ``M = F * H * W`` is the number of
26
+ patchified conditioning tokens. Values in ``[0, 1]``.
27
+ Example::
28
+ cond = ConditioningItemAttentionStrengthWrapper(
29
+ VideoConditionByReferenceLatent(latent=ref, strength=1.0),
30
+ attention_mask=0.5,
31
+ )
32
+ state = cond.apply_to(latent_state, latent_tools)
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ conditioning: ConditioningItem,
38
+ attention_mask: float | torch.Tensor,
39
+ ):
40
+ self.conditioning = conditioning
41
+ self.attention_mask = attention_mask
42
+
43
+ def apply_to(
44
+ self,
45
+ latent_state: LatentState,
46
+ latent_tools: LatentTools,
47
+ ) -> LatentState:
48
+ """Apply inner conditioning, then build the attention mask for its tokens."""
49
+ # Snapshot the original state for mask building
50
+ original_state = latent_state
51
+
52
+ # Inner conditioning appends tokens (positions, denoise mask, etc.)
53
+ new_state = self.conditioning.apply_to(latent_state, latent_tools)
54
+
55
+ num_new_tokens = new_state.latent.shape[1] - original_state.latent.shape[1]
56
+ if num_new_tokens == 0:
57
+ return new_state
58
+
59
+ # Build the attention mask using the *original* state as the reference
60
+ # so that the block structure is computed correctly.
61
+ new_attention_mask = update_attention_mask(
62
+ latent_state=original_state,
63
+ attention_mask=self.attention_mask,
64
+ num_noisy_tokens=latent_tools.target_shape.token_count(),
65
+ num_new_tokens=num_new_tokens,
66
+ batch_size=new_state.latent.shape[0],
67
+ device=new_state.latent.device,
68
+ dtype=new_state.latent.dtype,
69
+ )
70
+
71
+ return replace(new_state, attention_mask=new_attention_mask)
packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.components.patchifiers import get_pixel_coords
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.conditioning.mask_utils import update_attention_mask
6
+ from ltx_core.tools import VideoLatentTools
7
+ from ltx_core.types import LatentState, VideoLatentShape
8
+
9
+
10
+ class VideoConditionByKeyframeIndex(ConditioningItem):
11
+ """
12
+ Conditions video generation on keyframe latents at a specific frame index.
13
+ Appends keyframe tokens to the latent state with positions offset by frame_idx,
14
+ and sets denoise strength according to the strength parameter.
15
+ To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.
16
+ Args:
17
+ keyframes: Keyframe latents [B, C, F, H, W].
18
+ frame_idx: Frame index offset for positional encoding.
19
+ strength: Conditioning strength (1.0 = clean, 0.0 = fully denoised).
20
+ """
21
+
22
+ def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
23
+ self.keyframes = keyframes
24
+ self.frame_idx = frame_idx
25
+ self.strength = strength
26
+
27
+ def apply_to(
28
+ self,
29
+ latent_state: LatentState,
30
+ latent_tools: VideoLatentTools,
31
+ ) -> LatentState:
32
+ tokens = latent_tools.patchifier.patchify(self.keyframes)
33
+ latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
34
+ output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
35
+ device=self.keyframes.device,
36
+ )
37
+ positions = get_pixel_coords(
38
+ latent_coords=latent_coords,
39
+ scale_factors=latent_tools.scale_factors,
40
+ causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
41
+ )
42
+
43
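+ # Shift temporal positions by the keyframe's frame index, then convert from frame units to seconds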
+ positions[:, 0, ...] += self.frame_idx
44
+ positions = positions.to(dtype=torch.float32)
45
+ positions[:, 0, ...] /= latent_tools.fps
46
+
47
+ denoise_mask = torch.full(
48
+ size=(*tokens.shape[:2], 1),
49
+ fill_value=1.0 - self.strength,
50
+ device=self.keyframes.device,
51
+ dtype=self.keyframes.dtype,
52
+ )
53
+
54
+ new_attention_mask = update_attention_mask(
55
+ latent_state=latent_state,
56
+ attention_mask=None,
57
+ num_noisy_tokens=latent_tools.target_shape.token_count(),
58
+ num_new_tokens=tokens.shape[1],
59
+ batch_size=tokens.shape[0],
60
+ device=self.keyframes.device,
61
+ dtype=self.keyframes.dtype,
62
+ )
63
+
64
+ return LatentState(
65
+ latent=torch.cat([latent_state.latent, tokens], dim=1),
66
+ denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
67
+ positions=torch.cat([latent_state.positions, positions], dim=2),
68
+ clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
69
+ attention_mask=new_attention_mask,
70
+ )
packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.conditioning.exceptions import ConditioningError
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.tools import LatentTools
6
+ from ltx_core.types import LatentState
7
+
8
+
9
+ class VideoConditionByLatentIndex(ConditioningItem):
10
+ """
11
+ Conditions video generation by injecting latents at a specific latent frame index.
12
+ Replaces tokens in the latent state at positions corresponding to latent_idx,
13
+ and sets denoise strength according to the strength parameter.
14
+ """
15
+
16
+ def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int):
17
+ self.latent = latent
18
+ self.strength = strength
19
+ self.latent_idx = latent_idx
20
+
21
+ def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
22
+ cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape
23
+ tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape()
24
+
25
+ if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width):
26
+ raise ConditioningError(
27
+ f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected "
28
+ f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure "
29
+ "the image and latent have the same spatial shape."
30
+ )
31
+
32
+ tokens = latent_tools.patchifier.patchify(self.latent)
33
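+ # The token count of the first latent_idx latent frames gives the insertion offset into the sequence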
+ start_token = latent_tools.patchifier.get_token_count(
34
+ latent_tools.target_shape._replace(frames=self.latent_idx)
35
+ )
36
+ stop_token = start_token + tokens.shape[1]
37
+
38
+ latent_state = latent_state.clone()
39
+
40
+ latent_state.latent[:, start_token:stop_token] = tokens
41
+ latent_state.clean_latent[:, start_token:stop_token] = tokens
42
+ latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength
43
+
44
+ return latent_state
packages/ltx-core/src/ltx_core/conditioning/types/reference_video_cond.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reference video conditioning for IC-LoRA inference."""
2
+
3
+ import torch
4
+
5
+ from ltx_core.components.patchifiers import get_pixel_coords
6
+ from ltx_core.conditioning.item import ConditioningItem
7
+ from ltx_core.conditioning.mask_utils import update_attention_mask
8
+ from ltx_core.tools import VideoLatentTools
9
+ from ltx_core.types import LatentState, VideoLatentShape
10
+
11
+
12
+ class VideoConditionByReferenceLatent(ConditioningItem):
13
+ """
14
+ Conditions video generation on a reference video latent for IC-LoRA inference.
15
+ IC-LoRAs are trained by concatenating reference (control signal) and target tokens,
16
+ learning to attend across both. This class replicates that setup at inference by
17
+ appending reference tokens to the latent sequence.
18
+ IC-LoRAs can be trained with lower-resolution references than the target (e.g., 384px
19
+ reference for 768px output) for efficiency and better generalization. The
20
+ `downscale_factor` scales reference positions to match target coordinates, preserving
21
+ the learned positional relationships. This must match the factor used during training
22
+ (stored in LoRA metadata).
23
+ To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.
24
+ Args:
25
+ latent: Reference video latents [B, C, F, H, W]
26
+ downscale_factor: Target/reference resolution ratio (e.g., 2 = half-resolution
27
+ reference). Spatial positions are scaled by this factor.
28
+ strength: Conditioning strength. 1.0 = full (reference kept clean),
29
+ 0.0 = none (reference denoised). Default 1.0.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ latent: torch.Tensor,
35
+ downscale_factor: int = 1,
36
+ strength: float = 1.0,
37
+ ):
38
+ self.latent = latent
39
+ self.downscale_factor = downscale_factor
40
+ self.strength = strength
41
+
42
+ def apply_to(
43
+ self,
44
+ latent_state: LatentState,
45
+ latent_tools: VideoLatentTools,
46
+ ) -> LatentState:
47
+ """Append reference video tokens with scaled positions."""
48
+ tokens = latent_tools.patchifier.patchify(self.latent)
49
+
50
+ # Compute positions for the reference video's actual dimensions
51
+ latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
52
+ output_shape=VideoLatentShape.from_torch_shape(self.latent.shape),
53
+ device=self.latent.device,
54
+ )
55
+ positions = get_pixel_coords(
56
+ latent_coords=latent_coords,
57
+ scale_factors=latent_tools.scale_factors,
58
+ causal_fix=latent_tools.causal_fix,
59
+ )
60
+ positions = positions.to(dtype=torch.float32)
61
+ positions[:, 0, ...] /= latent_tools.fps
62
+
63
+ # Scale spatial positions to match target coordinate space
64
+ if self.downscale_factor != 1:
65
+ positions[:, 1, ...] *= self.downscale_factor # height axis
66
+ positions[:, 2, ...] *= self.downscale_factor # width axis
67
+
68
+ denoise_mask = torch.full(
69
+ size=(*tokens.shape[:2], 1),
70
+ fill_value=1.0 - self.strength,
71
+ device=self.latent.device,
72
+ dtype=self.latent.dtype,
73
+ )
74
+
75
+ new_attention_mask = update_attention_mask(
76
+ latent_state=latent_state,
77
+ attention_mask=None,
78
+ num_noisy_tokens=latent_tools.target_shape.token_count(),
79
+ num_new_tokens=tokens.shape[1],
80
+ batch_size=tokens.shape[0],
81
+ device=self.latent.device,
82
+ dtype=self.latent.dtype,
83
+ )
84
+
85
+ return LatentState(
86
+ latent=torch.cat([latent_state.latent, tokens], dim=1),
87
+ denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
88
+ positions=torch.cat([latent_state.positions, positions], dim=2),
89
+ clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
90
+ attention_mask=new_attention_mask,
91
+ )
packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (454 Bytes). View file
 
packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-312.pyc ADDED
Binary file (5.7 kB). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
+ class CausalityAxis(Enum):
5
+ """Enum for specifying the causality axis in causal convolutions."""
6
+
7
+ NONE = None
8
+ WIDTH = "width"
9
+ HEIGHT = "height"
10
+ WIDTH_COMPATIBILITY = "width-compatibility"
packages/ltx-core/src/ltx_core/model/common/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Common model utilities."""
2
+
3
+ from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer
4
+
5
+ __all__ = [
6
+ "NormType",
7
+ "PixelNorm",
8
+ "build_normalization_layer",
9
+ ]
packages/ltx-core/src/ltx_core/model/common/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (378 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/common/__pycache__/normalization.cpython-312.pyc ADDED
Binary file (3.18 kB). View file
 
packages/ltx-core/src/ltx_core/model/common/normalization.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class NormType(Enum):
8
+ """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
9
+
10
+ GROUP = "group"
11
+ PIXEL = "pixel"
12
+
13
+
14
+ class PixelNorm(nn.Module):
15
+ """
16
+ Per-pixel (per-location) RMS normalization layer.
17
+ For each element along the chosen dimension, this layer normalizes the tensor
18
+ by the root-mean-square of its values across that dimension:
19
+ y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
20
+ """
21
+
22
+ def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
23
+ """
24
+ Args:
25
+ dim: Dimension along which to compute the RMS (typically channels).
26
+ eps: Small constant added for numerical stability.
27
+ """
28
+ super().__init__()
29
+ self.dim = dim
30
+ self.eps = eps
31
+
32
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
33
+ """
34
+ Apply RMS normalization along the configured dimension.
35
+ """
36
+ # Compute mean of squared values along `dim`, keep dimensions for broadcasting.
37
+ mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
38
+ # Normalize by the root-mean-square (RMS).
39
+ rms = torch.sqrt(mean_sq + self.eps)
40
+ return x / rms
41
+
42
+
43
+ def build_normalization_layer(
44
+ in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
45
+ ) -> nn.Module:
46
+ """
47
+ Create a normalization layer based on the normalization type.
48
+ Args:
49
+ in_channels: Number of input channels
50
+ num_groups: Number of groups for group normalization
51
+ normtype: Type of normalization: "group" or "pixel"
52
+ Returns:
53
+ A normalization layer
54
+ """
55
+ if normtype == NormType.GROUP:
56
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
57
+ if normtype == NormType.PIXEL:
58
+ return PixelNorm(dim=1, eps=1e-6)
59
+ raise ValueError(f"Invalid normalization type: {normtype}")
packages/ltx-core/src/ltx_core/model/transformer/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer model components."""
2
+
3
+ from ltx_core.model.transformer.modality import Modality
4
+ from ltx_core.model.transformer.model import LTXModel, X0Model
5
+ from ltx_core.model.transformer.model_configurator import (
6
+ LTXV_MODEL_COMFY_RENAMING_MAP,
7
+ LTXModelConfigurator,
8
+ LTXVideoOnlyModelConfigurator,
9
+ )
10
+
11
+ __all__ = [
12
+ "LTXV_MODEL_COMFY_RENAMING_MAP",
13
+ "LTXModel",
14
+ "LTXModelConfigurator",
15
+ "LTXVideoOnlyModelConfigurator",
16
+ "Modality",
17
+ "X0Model",
18
+ ]
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/adaln.cpython-312.pyc ADDED
Binary file (2.6 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer_args.cpython-312.pyc ADDED
Binary file (13.6 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/adaln.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
6
+
7
+ # Number of AdaLN modulation parameters per transformer block.
8
+ # Base: 2 params (shift + scale) x 3 norms (self-attn, feed-forward, output).
9
+ ADALN_NUM_BASE_PARAMS = 6
10
+ # Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
11
+ ADALN_NUM_CROSS_ATTN_PARAMS = 3
12
+
13
+
14
+ def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
15
+ """Total number of AdaLN parameters per block."""
16
+ return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0)
17
+
18
+
19
+ class AdaLayerNormSingle(torch.nn.Module):
20
+ r"""
21
+ Norm layer adaptive layer norm single (adaLN-single).
22
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
23
+ Parameters:
24
+ embedding_dim (`int`): The size of each embedding vector.
25
+ embedding_coefficient (`int`): Number of AdaLN modulation parameters produced per embedding vector (defaults to 6; see ``adaln_embedding_coefficient``).
26
+ """
27
+
28
+ def __init__(self, embedding_dim: int, embedding_coefficient: int = 6):
29
+ super().__init__()
30
+
31
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
32
+ embedding_dim,
33
+ size_emb_dim=embedding_dim // 3,
34
+ )
35
+
36
+ self.silu = torch.nn.SiLU()
37
+ self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
38
+
39
+ def forward(
40
+ self,
41
+ timestep: torch.Tensor,
42
+ hidden_dtype: Optional[torch.dtype] = None,
43
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
44
+ embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
45
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
packages/ltx-core/src/ltx_core/model/transformer/attention.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb
7
+
8
+ memory_efficient_attention = None
9
+ flash_attn_interface = None
10
+ try:
11
+ from xformers.ops import memory_efficient_attention
12
+ except ImportError:
13
+ memory_efficient_attention = None
14
+ try:
15
+ # FlashAttention3 and XFormersAttention cannot be used together
16
+ if memory_efficient_attention is None:
17
+ import flash_attn_interface
18
+ except ImportError:
19
+ flash_attn_interface = None
20
+
21
+
22
+ class AttentionCallable(Protocol):
23
+ def __call__(
24
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
25
+ ) -> torch.Tensor: ...
26
+
27
+
28
+ class PytorchAttention(AttentionCallable):
29
+ def __call__(
30
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
31
+ ) -> torch.Tensor:
32
+ b, _, dim_head = q.shape
33
+ dim_head //= heads
34
+ q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
35
+
36
+ if mask is not None:
37
+ # add a batch dimension if there isn't already one
38
+ if mask.ndim == 2:
39
+ mask = mask.unsqueeze(0)
40
+ # add a heads dimension if there isn't already one
41
+ if mask.ndim == 3:
42
+ mask = mask.unsqueeze(1)
43
+
44
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
45
+ out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
46
+ return out
47
+
48
+
49
+ class XFormersAttention(AttentionCallable):
50
+ def __call__(
51
+ self,
52
+ q: torch.Tensor,
53
+ k: torch.Tensor,
54
+ v: torch.Tensor,
55
+ heads: int,
56
+ mask: torch.Tensor | None = None,
57
+ ) -> torch.Tensor:
58
+ if memory_efficient_attention is None:
59
+ raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.")
60
+
61
+ b, _, dim_head = q.shape
62
+ dim_head //= heads
63
+
64
+ # xformers expects [B, M, H, K]
65
+ q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
66
+
67
+ if mask is not None:
68
+ # add a singleton batch dimension
69
+ if mask.ndim == 2:
70
+ mask = mask.unsqueeze(0)
71
+ # add a singleton heads dimension
72
+ if mask.ndim == 3:
73
+ mask = mask.unsqueeze(1)
74
+ # pad to a multiple of 8
75
+ pad = -mask.shape[-1] % 8  # 0 when the last dim is already a multiple of 8
76
+ # the xformers docs say that a mask of shape (1, Nq, Nk) is allowed
77
+ # but when using separated heads, the shape has to be (B, H, Nq, Nk)
78
+ # in flux, this matrix ends up being over 1GB
79
+ # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
80
+ mask_out = torch.empty(
81
+ [mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device
82
+ )
83
+
84
+ mask_out[..., : mask.shape[-1]] = mask
85
+ # slice back to the logical size: the view restores the original key length while the underlying storage stays padded to the alignment xformers requires
86
+ mask = mask_out[..., : mask.shape[-1]]
87
+ mask = mask.expand(b, heads, -1, -1)
88
+
89
+ out = memory_efficient_attention(q.to(v.dtype), k.to(v.dtype), v, attn_bias=mask, p=0.0)
90
+ out = out.reshape(b, -1, heads * dim_head)
91
+ return out
92
+
93
+
94
+ class FlashAttention3(AttentionCallable):
95
+ def __call__(
96
+ self,
97
+ q: torch.Tensor,
98
+ k: torch.Tensor,
99
+ v: torch.Tensor,
100
+ heads: int,
101
+ mask: torch.Tensor | None = None,
102
+ ) -> torch.Tensor:
103
+ if flash_attn_interface is None:
104
+ raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.")
105
+
106
+ b, _, dim_head = q.shape
107
+ dim_head //= heads
108
+
109
+ q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
110
+
111
+ if mask is not None:
112
+ raise NotImplementedError("Mask is not supported for FlashAttention3")
113
+
114
+ out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v)
115
+ out = out.reshape(b, -1, heads * dim_head)
116
+ return out
117
+
118
+
119
+ class AttentionFunction(Enum):
120
+ PYTORCH = "pytorch"
121
+ XFORMERS = "xformers"
122
+ FLASH_ATTENTION_3 = "flash_attention_3"
123
+ DEFAULT = "default"
124
+
125
+ def __call__(
126
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
127
+ ) -> torch.Tensor:
128
+ if self is AttentionFunction.PYTORCH:
129
+ return PytorchAttention()(q, k, v, heads, mask)
130
+ elif self is AttentionFunction.XFORMERS:
131
+ return XFormersAttention()(q, k, v, heads, mask)
132
+ elif self is AttentionFunction.FLASH_ATTENTION_3:
133
+ return FlashAttention3()(q, k, v, heads, mask)
134
+ else:
135
+ # Default behavior: XFormers if installed, otherwise PyTorch SDPA
136
+ return (
137
+ XFormersAttention()(q, k, v, heads, mask)
138
+ if memory_efficient_attention is not None
139
+ else PytorchAttention()(q, k, v, heads, mask)
140
+ )
141
+
142
+
143
+ class Attention(torch.nn.Module):
144
+ def __init__(
145
+ self,
146
+ query_dim: int,
147
+ context_dim: int | None = None,
148
+ heads: int = 8,
149
+ dim_head: int = 64,
150
+ norm_eps: float = 1e-6,
151
+ rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
152
+ attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT,
153
+ apply_gated_attention: bool = False,
154
+ ) -> None:
155
+ super().__init__()
156
+ self.rope_type = rope_type
157
+ self.attention_function = attention_function
158
+
159
+ inner_dim = dim_head * heads
160
+ context_dim = query_dim if context_dim is None else context_dim
161
+
162
+ self.heads = heads
163
+ self.dim_head = dim_head
164
+
165
+ self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
166
+ self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
167
+
168
+ self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
169
+ self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
170
+ self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
171
+
172
+ # Optional per-head gating
173
+ if apply_gated_attention:
174
+ self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
175
+ else:
176
+ self.to_gate_logits = None
177
+
178
+ self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())
179
+
180
+ def forward(
181
+ self,
182
+ x: torch.Tensor,
183
+ context: torch.Tensor | None = None,
184
+ mask: torch.Tensor | None = None,
185
+ pe: torch.Tensor | None = None,
186
+ k_pe: torch.Tensor | None = None,
187
+ perturbation_mask: torch.Tensor | None = None,
188
+ all_perturbed: bool = False,
189
+ ) -> torch.Tensor:
190
+ """Multi-head attention with optional RoPE, perturbation masking, and per-head gating.
191
+ When ``perturbation_mask`` is all zeros, the expensive query/key path
192
+ (linear projections, RMSNorm, RoPE) is skipped entirely and only the
193
+ value projection is used as a pass-through.
194
+ Args:
195
+ x: Query input tensor of shape ``(B, T, query_dim)``.
196
+ context: Key/value context tensor of shape ``(B, S, context_dim)``.
197
+ Falls back to ``x`` (self-attention) when *None*.
198
+ mask: Optional attention mask. Interpretation depends on the attention
199
+ backend (additive bias for xformers/PyTorch SDPA).
200
+ pe: Rotary positional embeddings applied to both ``q`` and ``k``.
201
+ k_pe: Separate rotary positional embeddings for ``k`` only. When
202
+ *None*, ``pe`` is reused for keys.
203
+ perturbation_mask: Optional mask in ``[0, 1]`` that
204
+ blends the attention output with the raw value projection:
205
+ ``out = attn_out * mask + v * (1 - mask)``.
206
+ **1** keeps the full attention output, **0** bypasses attention
207
+ and passes the value projection through unchanged.
208
+ *None* or all-ones means standard attention; all-zeros skips
209
+ the query/key path entirely for efficiency.
210
+ all_perturbed: Whether all perturbations are active for this block.
211
+ Returns:
212
+ Output tensor of shape ``(B, T, query_dim)``.
213
+ """
214
+ context = x if context is None else context
215
+ use_attention = not all_perturbed
216
+
217
+ v = self.to_v(context)
218
+
219
+ if not use_attention:
220
+ out = v
221
+ else:
222
+ q = self.to_q(x)
223
+ k = self.to_k(context)
224
+
225
+ q = self.q_norm(q)
226
+ k = self.k_norm(k)
227
+
228
+ if pe is not None:
229
+ q = apply_rotary_emb(q, pe, self.rope_type)
230
+ k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)
231
+
232
+ out = self.attention_function(q, k, v, self.heads, mask) # (B, T, H*D)
233
+
234
+ if perturbation_mask is not None:
235
+ out = out * perturbation_mask + v * (1 - perturbation_mask)
236
+
237
+ # Apply per-head gating if enabled
238
+ if self.to_gate_logits is not None:
239
+ gate_logits = self.to_gate_logits(x) # (B, T, H)
240
+ b, t, _ = out.shape
241
+ # Reshape to (B, T, H, D) for per-head gating
242
+ out = out.view(b, t, self.heads, self.dim_head)
243
+ # Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
244
+ gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H)
245
+ out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1)
246
+ # Reshape back to (B, T, H*D)
247
+ out = out.view(b, t, self.heads * self.dim_head)
248
+
249
+ return self.to_out(out)
packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class GELUApprox(torch.nn.Module):
5
+ def __init__(self, dim_in: int, dim_out: int) -> None:
6
+ super().__init__()
7
+ self.proj = torch.nn.Linear(dim_in, dim_out)
8
+
9
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
10
+ return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
packages/ltx-core/src/ltx_core/model/transformer/modality.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
+ @dataclass(frozen=True)
7
+ class Modality:
8
+ """
9
+ Input data for a single modality (video or audio) in the transformer.
10
+ Bundles the latent tokens, timestep embeddings, positional information,
11
+ and text conditioning context for processing by the diffusion transformer.
12
+ Attributes:
13
+ latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
14
+ the batch size, *T* is the total number of tokens (noisy +
15
+ conditioning), and *D* is the input dimension.
16
+ sigma: Current sigma value, shape ``(B,)``, used for cross-attention timestep calculation.
+ timesteps: Per-token timestep embeddings, shape ``(B, T)``.
17
+ positions: Positional coordinates, shape ``(B, 3, T)`` for video
18
+ (time, height, width) or ``(B, 1, T)`` for audio.
19
+ context: Text conditioning embeddings from the prompt encoder.
20
+ enabled: Whether this modality is active in the current forward pass.
21
+ context_mask: Optional mask for the text context tokens.
22
+ attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
23
+ Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
24
+ attention. ``None`` means unrestricted (full) attention between
25
+ all tokens. Built incrementally by conditioning items; see
26
+ :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
27
+ """
28
+
29
+ latent: (
30
+ torch.Tensor
31
+ ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
32
+ sigma: torch.Tensor # Shape: (B,). Current sigma value, used for cross-attention timestep calculation.
33
+ timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
34
+ positions: (
35
+ torch.Tensor
36
+ ) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
37
+ context: torch.Tensor
38
+ enabled: bool = True
39
+ context_mask: torch.Tensor | None = None
40
+ attention_mask: torch.Tensor | None = None
packages/ltx-core/src/ltx_core/model/transformer/model.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+ from ltx_core.guidance.perturbations import BatchedPerturbationConfig
6
+ from ltx_core.model.transformer.adaln import AdaLayerNormSingle, adaln_embedding_coefficient
7
+ from ltx_core.model.transformer.attention import AttentionCallable, AttentionFunction
8
+ from ltx_core.model.transformer.modality import Modality
9
+ from ltx_core.model.transformer.rope import LTXRopeType
10
+ from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, TransformerConfig
11
+ from ltx_core.model.transformer.transformer_args import (
12
+ MultiModalTransformerArgsPreprocessor,
13
+ TransformerArgs,
14
+ TransformerArgsPreprocessor,
15
+ )
16
+ from ltx_core.utils import to_denoised
17
+
18
+
19
+ class LTXModelType(Enum):
20
+ AudioVideo = "ltx av model"
21
+ VideoOnly = "ltx video only model"
22
+ AudioOnly = "ltx audio only model"
23
+
24
+ def is_video_enabled(self) -> bool:
25
+ return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
26
+
27
+ def is_audio_enabled(self) -> bool:
28
+ return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
29
+
30
+
31
+ class LTXModel(torch.nn.Module):
32
+ """
33
+ LTX model transformer implementation.
34
+ This class implements the transformer blocks for the LTX model.
35
+ """
36
+
37
+ def __init__( # noqa: PLR0913
38
+ self,
39
+ *,
40
+ model_type: LTXModelType = LTXModelType.AudioVideo,
41
+ num_attention_heads: int = 32,
42
+ attention_head_dim: int = 128,
43
+ in_channels: int = 128,
44
+ out_channels: int = 128,
45
+ num_layers: int = 48,
46
+ cross_attention_dim: int = 4096,
47
+ norm_eps: float = 1e-06,
48
+ attention_type: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT,
49
+ positional_embedding_theta: float = 10000.0,
50
+ positional_embedding_max_pos: list[int] | None = None,
51
+ timestep_scale_multiplier: int = 1000,
52
+ use_middle_indices_grid: bool = True,
53
+ audio_num_attention_heads: int = 32,
54
+ audio_attention_head_dim: int = 64,
55
+ audio_in_channels: int = 128,
56
+ audio_out_channels: int = 128,
57
+ audio_cross_attention_dim: int = 2048,
58
+ audio_positional_embedding_max_pos: list[int] | None = None,
59
+ av_ca_timestep_scale_multiplier: int = 1,
60
+ rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
61
+ double_precision_rope: bool = False,
62
+ apply_gated_attention: bool = False,
63
+ caption_projection: torch.nn.Module | None = None,
64
+ audio_caption_projection: torch.nn.Module | None = None,
65
+ cross_attention_adaln: bool = False,
66
+ ):
67
+ super().__init__()
68
+ self._enable_gradient_checkpointing = False
69
+ self.cross_attention_adaln = cross_attention_adaln
70
+ self.use_middle_indices_grid = use_middle_indices_grid
71
+ self.rope_type = rope_type
72
+ self.double_precision_rope = double_precision_rope
73
+ self.timestep_scale_multiplier = timestep_scale_multiplier
74
+ self.positional_embedding_theta = positional_embedding_theta
75
+ self.model_type = model_type
76
+ cross_pe_max_pos = None
77
+ if model_type.is_video_enabled():
78
+ if positional_embedding_max_pos is None:
79
+ positional_embedding_max_pos = [20, 2048, 2048]
80
+ self.positional_embedding_max_pos = positional_embedding_max_pos
81
+ self.num_attention_heads = num_attention_heads
82
+ self.inner_dim = num_attention_heads * attention_head_dim
83
+ self._init_video(
84
+ in_channels=in_channels,
85
+ out_channels=out_channels,
86
+ norm_eps=norm_eps,
87
+ caption_projection=caption_projection,
88
+ )
89
+
90
+ if model_type.is_audio_enabled():
91
+ if audio_positional_embedding_max_pos is None:
92
+ audio_positional_embedding_max_pos = [20]
93
+ self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
94
+ self.audio_num_attention_heads = audio_num_attention_heads
95
+ self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim
96
+ self._init_audio(
97
+ in_channels=audio_in_channels,
98
+ out_channels=audio_out_channels,
99
+ norm_eps=norm_eps,
100
+ caption_projection=audio_caption_projection,
101
+ )
102
+
103
+ if model_type.is_video_enabled() and model_type.is_audio_enabled():
104
+ cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
105
+ self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
106
+ self.audio_cross_attention_dim = audio_cross_attention_dim
107
+ self._init_audio_video(num_scale_shift_values=4)
108
+
109
+ self._init_preprocessors(cross_pe_max_pos)
110
+ # Initialize transformer blocks
111
+ self._init_transformer_blocks(
112
+ num_layers=num_layers,
113
+ attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0,
114
+ cross_attention_dim=cross_attention_dim,
115
+ audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
116
+ audio_cross_attention_dim=audio_cross_attention_dim,
117
+ norm_eps=norm_eps,
118
+ attention_type=attention_type,
119
+ apply_gated_attention=apply_gated_attention,
120
+ )
121
+
122
+ @property
123
+ def _adaln_embedding_coefficient(self) -> int:
124
+ return adaln_embedding_coefficient(self.cross_attention_adaln)
125
+
126
+ def _init_video(
127
+ self,
128
+ in_channels: int,
129
+ out_channels: int,
130
+ norm_eps: float,
131
+ caption_projection: torch.nn.Module | None = None,
132
+ ) -> None:
133
+ """Initialize video-specific components."""
134
+ # Video input components
135
+ self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
136
+ if caption_projection is not None:
137
+ self.caption_projection = caption_projection
138
+
139
+ self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)
140
+
141
+ self.prompt_adaln_single = (
142
+ AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
143
+ )
144
+
145
+ # Video output components
146
+ self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
147
+ self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps)
148
+ self.proj_out = torch.nn.Linear(self.inner_dim, out_channels)
149
+
150
+ def _init_audio(
151
+ self,
152
+ in_channels: int,
153
+ out_channels: int,
154
+ norm_eps: float,
155
+ caption_projection: torch.nn.Module | None = None,
156
+ ) -> None:
157
+ """Initialize audio-specific components."""
158
+
159
+ # Audio input components
160
+ self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)
161
+ if caption_projection is not None:
162
+ self.audio_caption_projection = caption_projection
163
+
164
+ self.audio_adaln_single = AdaLayerNormSingle(
165
+ self.audio_inner_dim,
166
+ embedding_coefficient=self._adaln_embedding_coefficient,
167
+ )
168
+
169
+ self.audio_prompt_adaln_single = (
170
+ AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
171
+ )
172
+
173
+ # Audio output components
174
+ self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
175
+ self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps)
176
+ self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels)
177
+
178
+ def _init_audio_video(
179
+ self,
180
+ num_scale_shift_values: int,
181
+ ) -> None:
182
+ """Initialize audio-video cross-attention components."""
183
+ self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
184
+ self.inner_dim,
185
+ embedding_coefficient=num_scale_shift_values,
186
+ )
187
+
188
+ self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
189
+ self.audio_inner_dim,
190
+ embedding_coefficient=num_scale_shift_values,
191
+ )
192
+
193
+ self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
194
+ self.inner_dim,
195
+ embedding_coefficient=1,
196
+ )
197
+
198
+ self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
199
+ self.audio_inner_dim,
200
+ embedding_coefficient=1,
201
+ )
202
+
203
+ def _init_preprocessors(
204
+ self,
205
+ cross_pe_max_pos: int | None = None,
206
+ ) -> None:
207
+ """Initialize preprocessors for LTX."""
208
+
209
+ if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
210
+ self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
211
+ patchify_proj=self.patchify_proj,
212
+ adaln=self.adaln_single,
213
+ cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
214
+ cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
215
+ inner_dim=self.inner_dim,
216
+ max_pos=self.positional_embedding_max_pos,
217
+ num_attention_heads=self.num_attention_heads,
218
+ cross_pe_max_pos=cross_pe_max_pos,
219
+ use_middle_indices_grid=self.use_middle_indices_grid,
220
+ audio_cross_attention_dim=self.audio_cross_attention_dim,
221
+ timestep_scale_multiplier=self.timestep_scale_multiplier,
222
+ double_precision_rope=self.double_precision_rope,
223
+ positional_embedding_theta=self.positional_embedding_theta,
224
+ rope_type=self.rope_type,
225
+ av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
226
+ caption_projection=getattr(self, "caption_projection", None),
227
+ prompt_adaln=getattr(self, "prompt_adaln_single", None),
228
+ )
229
+ self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
230
+ patchify_proj=self.audio_patchify_proj,
231
+ adaln=self.audio_adaln_single,
232
+ cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
233
+ cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
234
+ inner_dim=self.audio_inner_dim,
235
+ max_pos=self.audio_positional_embedding_max_pos,
236
+ num_attention_heads=self.audio_num_attention_heads,
237
+ cross_pe_max_pos=cross_pe_max_pos,
238
+ use_middle_indices_grid=self.use_middle_indices_grid,
239
+ audio_cross_attention_dim=self.audio_cross_attention_dim,
240
+ timestep_scale_multiplier=self.timestep_scale_multiplier,
241
+ double_precision_rope=self.double_precision_rope,
242
+ positional_embedding_theta=self.positional_embedding_theta,
243
+ rope_type=self.rope_type,
244
+ av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
245
+ caption_projection=getattr(self, "audio_caption_projection", None),
246
+ prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
247
+ )
248
+ elif self.model_type.is_video_enabled():
249
+ self.video_args_preprocessor = TransformerArgsPreprocessor(
250
+ patchify_proj=self.patchify_proj,
251
+ adaln=self.adaln_single,
252
+ inner_dim=self.inner_dim,
253
+ max_pos=self.positional_embedding_max_pos,
254
+ num_attention_heads=self.num_attention_heads,
255
+ use_middle_indices_grid=self.use_middle_indices_grid,
256
+ timestep_scale_multiplier=self.timestep_scale_multiplier,
257
+ double_precision_rope=self.double_precision_rope,
258
+ positional_embedding_theta=self.positional_embedding_theta,
259
+ rope_type=self.rope_type,
260
+ caption_projection=getattr(self, "caption_projection", None),
261
+ prompt_adaln=getattr(self, "prompt_adaln_single", None),
262
+ )
263
+ elif self.model_type.is_audio_enabled():
264
+ self.audio_args_preprocessor = TransformerArgsPreprocessor(
265
+ patchify_proj=self.audio_patchify_proj,
266
+ adaln=self.audio_adaln_single,
267
+ inner_dim=self.audio_inner_dim,
268
+ max_pos=self.audio_positional_embedding_max_pos,
269
+ num_attention_heads=self.audio_num_attention_heads,
270
+ use_middle_indices_grid=self.use_middle_indices_grid,
271
+ timestep_scale_multiplier=self.timestep_scale_multiplier,
272
+ double_precision_rope=self.double_precision_rope,
273
+ positional_embedding_theta=self.positional_embedding_theta,
274
+ rope_type=self.rope_type,
275
+ caption_projection=getattr(self, "audio_caption_projection", None),
276
+ prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
277
+ )
278
+
279
+ def _init_transformer_blocks(
280
+ self,
281
+ num_layers: int,
282
+ attention_head_dim: int,
283
+ cross_attention_dim: int,
284
+ audio_attention_head_dim: int,
285
+ audio_cross_attention_dim: int,
286
+ norm_eps: float,
287
+ attention_type: AttentionFunction | AttentionCallable,
288
+ apply_gated_attention: bool,
289
+ ) -> None:
290
+ """Initialize transformer blocks for LTX."""
291
+ video_config = (
292
+ TransformerConfig(
293
+ dim=self.inner_dim,
294
+ heads=self.num_attention_heads,
295
+ d_head=attention_head_dim,
296
+ context_dim=cross_attention_dim,
297
+ apply_gated_attention=apply_gated_attention,
298
+ cross_attention_adaln=self.cross_attention_adaln,
299
+ )
300
+ if self.model_type.is_video_enabled()
301
+ else None
302
+ )
303
+ audio_config = (
304
+ TransformerConfig(
305
+ dim=self.audio_inner_dim,
306
+ heads=self.audio_num_attention_heads,
307
+ d_head=audio_attention_head_dim,
308
+ context_dim=audio_cross_attention_dim,
309
+ apply_gated_attention=apply_gated_attention,
310
+ cross_attention_adaln=self.cross_attention_adaln,
311
+ )
312
+ if self.model_type.is_audio_enabled()
313
+ else None
314
+ )
315
+ self.transformer_blocks = torch.nn.ModuleList(
316
+ [
317
+ BasicAVTransformerBlock(
318
+ idx=idx,
319
+ video=video_config,
320
+ audio=audio_config,
321
+ rope_type=self.rope_type,
322
+ norm_eps=norm_eps,
323
+ attention_function=attention_type,
324
+ )
325
+ for idx in range(num_layers)
326
+ ]
327
+ )
328
+
329
+ def set_gradient_checkpointing(self, enable: bool) -> None:
330
+ """Enable or disable gradient checkpointing for transformer blocks.
331
+ Gradient checkpointing trades compute for memory by recomputing activations
332
+ during the backward pass instead of storing them. This can significantly
333
+ reduce memory usage at the cost of ~20-30% slower training.
334
+ Args:
335
+ enable: Whether to enable gradient checkpointing
336
+ """
337
+ self._enable_gradient_checkpointing = enable
338
+
339
+ def _process_transformer_blocks(
340
+ self,
341
+ video: TransformerArgs | None,
342
+ audio: TransformerArgs | None,
343
+ perturbations: BatchedPerturbationConfig,
344
+ ) -> tuple[TransformerArgs, TransformerArgs]:
345
+ """Process transformer blocks for LTXAV."""
346
+
347
+ # Process transformer blocks
348
+ for block in self.transformer_blocks:
349
+ if self._enable_gradient_checkpointing and self.training:
350
+ # Use gradient checkpointing to save memory during training.
351
+ # With use_reentrant=False, we can pass dataclasses directly -
352
+ # PyTorch will track all tensor leaves in the computation graph.
353
+ video, audio = torch.utils.checkpoint.checkpoint(
354
+ block,
355
+ video,
356
+ audio,
357
+ perturbations,
358
+ use_reentrant=False,
359
+ )
360
+ else:
361
+ video, audio = block(
362
+ video=video,
363
+ audio=audio,
364
+ perturbations=perturbations,
365
+ )
366
+
367
+ return video, audio
368
+
369
+ def _process_output(
370
+ self,
371
+ scale_shift_table: torch.Tensor,
372
+ norm_out: torch.nn.LayerNorm,
373
+ proj_out: torch.nn.Linear,
374
+ x: torch.Tensor,
375
+ embedded_timestep: torch.Tensor,
376
+ ) -> torch.Tensor:
377
+ """Process output for LTXV."""
378
+ # Apply scale-shift modulation
379
+ scale_shift_values = (
380
+ scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
381
+ )
382
+ shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
383
+
384
+ x = norm_out(x)
385
+ x = x * (1 + scale) + shift
386
+ x = proj_out(x)
387
+ return x
388
+
389
+ def forward(
390
+ self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
391
+ ) -> tuple[torch.Tensor, torch.Tensor]:
392
+ """
393
+ Forward pass for LTX models.
394
+ Returns:
395
+ Processed output tensors
396
+ """
397
+ if not self.model_type.is_video_enabled() and video is not None:
398
+ raise ValueError("Video is not enabled for this model")
399
+ if not self.model_type.is_audio_enabled() and audio is not None:
400
+ raise ValueError("Audio is not enabled for this model")
401
+
402
+ video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
403
+ audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
404
+ # Process transformer blocks
405
+ video_out, audio_out = self._process_transformer_blocks(
406
+ video=video_args,
407
+ audio=audio_args,
408
+ perturbations=perturbations,
409
+ )
410
+
411
+ # Process output
412
+ vx = (
413
+ self._process_output(
414
+ self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep
415
+ )
416
+ if video_out is not None
417
+ else None
418
+ )
419
+ ax = (
420
+ self._process_output(
421
+ self.audio_scale_shift_table,
422
+ self.audio_norm_out,
423
+ self.audio_proj_out,
424
+ audio_out.x,
425
+ audio_out.embedded_timestep,
426
+ )
427
+ if audio_out is not None
428
+ else None
429
+ )
430
+ return vx, ax
431
+
432
+
433
+ class LegacyX0Model(torch.nn.Module):
434
+ """
435
+ Legacy X0 model implementation.
436
+ Returns fully denoised output based on the velocities produced by the base model.
437
+ """
438
+
439
+ def __init__(self, velocity_model: LTXModel):
440
+ super().__init__()
441
+ self.velocity_model = velocity_model
442
+
443
+ def forward(
444
+ self,
445
+ video: Modality | None,
446
+ audio: Modality | None,
447
+ perturbations: BatchedPerturbationConfig,
448
+ sigma: float,
449
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
450
+ """
451
+ Denoise the video and audio according to the sigma.
452
+ Returns:
453
+ Denoised video and audio
454
+ """
455
+ vx, ax = self.velocity_model(video, audio, perturbations)
456
+ denoised_video = to_denoised(video.latent, vx, sigma) if vx is not None else None
457
+ denoised_audio = to_denoised(audio.latent, ax, sigma) if ax is not None else None
458
+ return denoised_video, denoised_audio
459
+
460
+
461
+ class X0Model(torch.nn.Module):
462
+ """
463
+ X0 model implementation.
464
+ Returns fully denoised outputs based on the velocities produced by the base model.
465
+ Applies per-token denoising to the video and audio according to their timesteps (= sigma * denoising_mask).
466
+ """
467
+
468
+ def __init__(self, velocity_model: LTXModel):
469
+ super().__init__()
470
+ self.velocity_model = velocity_model
471
+
472
+ def forward(
473
+ self,
474
+ video: Modality | None,
475
+ audio: Modality | None,
476
+ perturbations: BatchedPerturbationConfig,
477
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
478
+ """
479
+ Denoise the video and audio according to their per-token timesteps.
480
+ Returns:
481
+ Denoised video and audio
482
+ """
483
+ vx, ax = self.velocity_model(video, audio, perturbations)
484
+ denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
485
+ denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
486
+ return denoised_video, denoised_audio
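Note: to_denoised is imported from elsewhere in the package and is not shown in this diff. Under the rectified-flow parameterization typical of LTX-style velocity models, x_t = (1 - sigma) * x0 + sigma * noise and v = noise - x0, so the denoised sample is x0 = x_t - sigma * v. A minimal sketch under that assumption (name and details hypothetical; the real helper may differ):

import torch

def to_denoised_sketch(latent: torch.Tensor, velocity: torch.Tensor, sigma) -> torch.Tensor:
    # sigma may be a scalar (LegacyX0Model) or a per-token tensor broadcastable
    # against the latent (X0Model, where timesteps = sigma * denoising_mask).
    return latent - sigma * velocity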
packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+
3
+ from ltx_core.loader.sd_ops import SDOps
4
+ from ltx_core.model.model_protocol import ModelConfigurator
5
+ from ltx_core.model.transformer.attention import AttentionFunction
6
+ from ltx_core.model.transformer.model import LTXModel, LTXModelType
7
+ from ltx_core.model.transformer.rope import LTXRopeType
8
+ from ltx_core.model.transformer.text_projection import create_caption_projection
9
+ from ltx_core.utils import check_config_value
10
+
11
+
12
+ class LTXModelConfigurator(ModelConfigurator[LTXModel]):
13
+ """
14
+ Configurator for LTX model.
15
+ Used to create an LTX model from a configuration dictionary.
16
+ """
17
+
18
+ @classmethod
19
+ def from_config(cls, config: dict) -> LTXModel:
20
+ # Build caption projections for 19B models (projection handled in transformer).
21
+ caption_projection, audio_caption_projection = _build_caption_projections(config, is_av=True)
22
+
23
+ config = config.get("transformer", {})
24
+
25
+ check_config_value(config, "dropout", 0.0)
26
+ check_config_value(config, "attention_bias", True)
27
+ check_config_value(config, "num_vector_embeds", None)
28
+ check_config_value(config, "activation_fn", "gelu-approximate")
29
+ check_config_value(config, "num_embeds_ada_norm", 1000)
30
+ check_config_value(config, "use_linear_projection", False)
31
+ check_config_value(config, "only_cross_attention", False)
32
+ check_config_value(config, "cross_attention_norm", True)
33
+ check_config_value(config, "double_self_attention", False)
34
+ check_config_value(config, "upcast_attention", False)
35
+ check_config_value(config, "standardization_norm", "rms_norm")
36
+ check_config_value(config, "norm_elementwise_affine", False)
37
+ check_config_value(config, "qk_norm", "rms_norm")
38
+ check_config_value(config, "positional_embedding_type", "rope")
39
+ check_config_value(config, "use_audio_video_cross_attention", True)
40
+ check_config_value(config, "share_ff", False)
41
+ check_config_value(config, "av_cross_ada_norm", True)
42
+ check_config_value(config, "use_middle_indices_grid", True)
43
+
44
+ return LTXModel(
45
+ model_type=LTXModelType.AudioVideo,
46
+ num_attention_heads=config.get("num_attention_heads", 32),
47
+ attention_head_dim=config.get("attention_head_dim", 128),
48
+ in_channels=config.get("in_channels", 128),
49
+ out_channels=config.get("out_channels", 128),
50
+ num_layers=config.get("num_layers", 48),
51
+ cross_attention_dim=config.get("cross_attention_dim", 4096),
52
+ norm_eps=config.get("norm_eps", 1e-06),
53
+ attention_type=AttentionFunction(config.get("attention_type", "default")),
54
+ positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
55
+ positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
56
+ timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
57
+ use_middle_indices_grid=config.get("use_middle_indices_grid", True),
58
+ audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
59
+ audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
60
+ audio_in_channels=config.get("audio_in_channels", 128),
61
+ audio_out_channels=config.get("audio_out_channels", 128),
62
+ audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
63
+ audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
64
+ av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
65
+ rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
66
+ double_precision_rope=config.get("frequencies_precision", False) == "float64",
67
+ apply_gated_attention=config.get("apply_gated_attention", False),
68
+ caption_projection=caption_projection,
69
+ audio_caption_projection=audio_caption_projection,
70
+ cross_attention_adaln=config.get("cross_attention_adaln", False),
71
+ )
72
+
73
+
74
+ class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
75
+ """
76
+ Configurator for LTX video only model.
77
+ Used to create an LTX video only model from a configuration dictionary.
78
+ """
79
+
80
+ @classmethod
81
+ def from_config(cls, config: dict) -> LTXModel:
82
+ # Build caption projection for 19B model (projection handled in transformer).
83
+ caption_projection, _ = _build_caption_projections(config, is_av=False)
84
+
85
+ config = config.get("transformer", {})
86
+
87
+ check_config_value(config, "dropout", 0.0)
88
+ check_config_value(config, "attention_bias", True)
89
+ check_config_value(config, "num_vector_embeds", None)
90
+ check_config_value(config, "activation_fn", "gelu-approximate")
91
+ check_config_value(config, "num_embeds_ada_norm", 1000)
92
+ check_config_value(config, "use_linear_projection", False)
93
+ check_config_value(config, "only_cross_attention", False)
94
+ check_config_value(config, "cross_attention_norm", True)
95
+ check_config_value(config, "double_self_attention", False)
96
+ check_config_value(config, "upcast_attention", False)
97
+ check_config_value(config, "standardization_norm", "rms_norm")
98
+ check_config_value(config, "norm_elementwise_affine", False)
99
+ check_config_value(config, "qk_norm", "rms_norm")
100
+ check_config_value(config, "positional_embedding_type", "rope")
101
+ check_config_value(config, "use_middle_indices_grid", True)
102
+
103
+ return LTXModel(
104
+ model_type=LTXModelType.VideoOnly,
105
+ num_attention_heads=config.get("num_attention_heads", 32),
106
+ attention_head_dim=config.get("attention_head_dim", 128),
107
+ in_channels=config.get("in_channels", 128),
108
+ out_channels=config.get("out_channels", 128),
109
+ num_layers=config.get("num_layers", 48),
110
+ cross_attention_dim=config.get("cross_attention_dim", 4096),
111
+ norm_eps=config.get("norm_eps", 1e-06),
112
+ attention_type=AttentionFunction(config.get("attention_type", "default")),
113
+ positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
114
+ positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
115
+ timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
116
+ use_middle_indices_grid=config.get("use_middle_indices_grid", True),
117
+ rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
118
+ double_precision_rope=config.get("frequencies_precision", False) == "float64",
119
+ apply_gated_attention=config.get("apply_gated_attention", False),
120
+ caption_projection=caption_projection,
121
+ cross_attention_adaln=config.get("cross_attention_adaln", False),
122
+ )
123
+
124
+
125
+ def _build_caption_projections(
126
+ config: dict,
127
+ is_av: bool,
128
+ ) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
129
+ """Build caption projections for the transformer when projection is NOT in the text encoder.
130
+ 19B models: projection is in the transformer (caption_proj_before_connector=False).
131
+ 22B models: projection is in the text encoder, so no projections are created here.
132
+ Args:
133
+ config: Full model config dict (must contain "transformer" key).
134
+ is_av: Whether this is an audio-video model. When False, audio projection is skipped.
135
+ Returns:
136
+ Tuple of (video_caption_projection, audio_caption_projection), both None for 22B models.
137
+ """
138
+ transformer_config = config.get("transformer", {})
139
+ if transformer_config.get("caption_proj_before_connector", False):
140
+ return None, None
141
+
142
+ with torch.device("meta"):
143
+ caption_projection = create_caption_projection(transformer_config)
144
+ audio_caption_projection = create_caption_projection(transformer_config, audio=True) if is_av else None
145
+ return caption_projection, audio_caption_projection
146
+
147
+
148
+ LTXV_MODEL_COMFY_RENAMING_MAP = (
149
+ SDOps("LTXV_MODEL_COMFY_PREFIX_MAP")
150
+ .with_matching(prefix="model.diffusion_model.")
151
+ .with_replacement("model.diffusion_model.", "")
152
+ )
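A sketch of driving the configurator from a checkpoint config (file name and exact config layout are assumptions; the dict must carry a "transformer" section as read above):

import json

from ltx_core.model.transformer.model_configurator import LTXModelConfigurator

with open("config.json") as f:   # hypothetical path
    config = json.load(f)

model = LTXModelConfigurator.from_config(config)
# Caption projections (19B-style configs) are built on the meta device and must be
# materialized when the state dict is loaded.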
packages/ltx-core/src/ltx_core/model/transformer/rope.py ADDED
@@ -0,0 +1,204 @@
1
+ import functools
2
+ import math
3
+ from enum import Enum
4
+ from typing import Callable, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ from einops import rearrange
9
+
10
+
11
+ class LTXRopeType(Enum):
12
+ INTERLEAVED = "interleaved"
13
+ SPLIT = "split"
14
+
15
+
16
+ def apply_rotary_emb(
17
+ input_tensor: torch.Tensor,
18
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor],
19
+ rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
20
+ ) -> torch.Tensor:
21
+ if rope_type == LTXRopeType.INTERLEAVED:
22
+ return apply_interleaved_rotary_emb(input_tensor, *freqs_cis)
23
+ elif rope_type == LTXRopeType.SPLIT:
24
+ return apply_split_rotary_emb(input_tensor, *freqs_cis)
25
+ else:
26
+ raise ValueError(f"Invalid rope type: {rope_type}")
27
+
28
+
29
+ def apply_interleaved_rotary_emb(
30
+ input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor
31
+ ) -> torch.Tensor:
32
+ t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
33
+ t1, t2 = t_dup.unbind(dim=-1)
34
+ t_dup = torch.stack((-t2, t1), dim=-1)
35
+ input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
36
+
37
+ out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
38
+
39
+ return out
40
+
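Each adjacent channel pair (t1, t2) is rotated to (t1*cos - t2*sin, t2*cos + t1*sin), so the transform preserves per-token norms whenever cos and sin come from the same angles. A quick self-check with toy shapes (a sketch, meant to run in this module's namespace):

import torch

angles = torch.randn(1, 4, 4)                    # (batch, tokens, dim // 2)
cos = angles.cos().repeat_interleave(2, dim=-1)  # (1, 4, 8), as interleaved_freqs_cis builds below
sin = angles.sin().repeat_interleave(2, dim=-1)
x = torch.randn(1, 4, 8)
y = apply_interleaved_rotary_emb(x, cos, sin)
assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)  # per-pair rotations preserve norm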
41
+
42
+ def apply_split_rotary_emb(
43
+ input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor
44
+ ) -> torch.Tensor:
45
+ needs_reshape = False
46
+ if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
47
+ b, h, t, _ = cos_freqs.shape
48
+ input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2)
49
+ needs_reshape = True
50
+
51
+ split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
52
+ first_half_input = split_input[..., :1, :]
53
+ second_half_input = split_input[..., 1:, :]
54
+
55
+ output = split_input * cos_freqs.unsqueeze(-2)
56
+ first_half_output = output[..., :1, :]
57
+ second_half_output = output[..., 1:, :]
58
+
59
+ first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input)
60
+ second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input)
61
+
62
+ output = rearrange(output, "... d r -> ... (d r)")
63
+ if needs_reshape:
64
+ output = output.swapaxes(1, 2).reshape(b, t, -1)
65
+
66
+ return output
67
+
68
+
69
+ @functools.lru_cache(maxsize=5)
70
+ def generate_freq_grid_np(
71
+ positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int
72
+ ) -> torch.Tensor:
73
+ theta = positional_embedding_theta
74
+ start = 1
75
+ end = theta
76
+
77
+ n_elem = 2 * positional_embedding_max_pos_count
78
+ pow_indices = np.power(
79
+ theta,
80
+ np.linspace(
81
+ np.log(start) / np.log(theta),
82
+ np.log(end) / np.log(theta),
83
+ inner_dim // n_elem,
84
+ dtype=np.float64,
85
+ ),
86
+ )
87
+ return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
88
+
89
+
90
+ @functools.lru_cache(maxsize=5)
91
+ def generate_freq_grid_pytorch(
92
+ positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int
93
+ ) -> torch.Tensor:
94
+ theta = positional_embedding_theta
95
+ start = 1
96
+ end = theta
97
+ n_elem = 2 * positional_embedding_max_pos_count
98
+
99
+ indices = theta ** (
100
+ torch.linspace(
101
+ math.log(start, theta),
102
+ math.log(end, theta),
103
+ inner_dim // n_elem,
104
+ dtype=torch.float32,
105
+ )
106
+ )
107
+ indices = indices.to(dtype=torch.float32)
108
+
109
+ indices = indices * math.pi / 2
110
+
111
+ return indices
112
+
113
+
114
+ def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor:
115
+ n_pos_dims = indices_grid.shape[1]
116
+ assert n_pos_dims == len(max_pos), (
117
+ f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
118
+ )
119
+ fractional_positions = torch.stack(
120
+ [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
121
+ dim=-1,
122
+ )
123
+ return fractional_positions
124
+
125
+
126
+ def generate_freqs(
127
+ indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool
128
+ ) -> torch.Tensor:
129
+ if use_middle_indices_grid:
130
+ assert len(indices_grid.shape) == 4
131
+ assert indices_grid.shape[-1] == 2
132
+ indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
133
+ indices_grid = (indices_grid_start + indices_grid_end) / 2.0
134
+ elif len(indices_grid.shape) == 4:
135
+ indices_grid = indices_grid[..., 0]
136
+
137
+ fractional_positions = get_fractional_positions(indices_grid, max_pos)
138
+ indices = indices.to(device=fractional_positions.device)
139
+
140
+ freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
141
+ return freqs
142
+
143
+
144
+ def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]:
145
+ cos_freq = freqs.cos()
146
+ sin_freq = freqs.sin()
147
+
148
+ if pad_size != 0:
149
+ cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
150
+ sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
151
+
152
+ cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
153
+ sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
154
+
155
+ # Reshape freqs to be compatible with multi-head attention
156
+ b = cos_freq.shape[0]
157
+ t = cos_freq.shape[1]
158
+
159
+ cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
160
+ sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
161
+
162
+ cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
163
+ sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
164
+ return cos_freq, sin_freq
165
+
166
+
167
+ def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]:
168
+ cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
169
+ sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
170
+ if pad_size != 0:
171
+ cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
172
+ sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size])
173
+ cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
174
+ sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
175
+ return cos_freq, sin_freq
176
+
177
+
178
+ def precompute_freqs_cis(
179
+ indices_grid: torch.Tensor,
180
+ dim: int,
181
+ out_dtype: torch.dtype,
182
+ theta: float = 10000.0,
183
+ max_pos: list[int] | None = None,
184
+ use_middle_indices_grid: bool = False,
185
+ num_attention_heads: int = 32,
186
+ rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
187
+ freq_grid_generator: Callable[[float, int, int], torch.Tensor] = generate_freq_grid_pytorch,
188
+ ) -> tuple[torch.Tensor, torch.Tensor]:
189
+ if max_pos is None:
190
+ max_pos = [20, 2048, 2048]
191
+
192
+ indices = freq_grid_generator(theta, indices_grid.shape[1], dim)
193
+ freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
194
+
195
+ if rope_type == LTXRopeType.SPLIT:
196
+ expected_freqs = dim // 2
197
+ current_freqs = freqs.shape[-1]
198
+ pad_size = expected_freqs - current_freqs
199
+ cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
200
+ else:
201
+ # 2 for cos and sin, times the number of position dims: 3 for (t, x, y), 1 for temporal-only
202
+ n_elem = 2 * indices_grid.shape[1]
203
+ cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
204
+ return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
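A sketch of the full entry point with a toy (batch, n_pos_dims, n_tokens) indices grid; values here are arbitrary:

import torch

grid = torch.zeros(1, 3, 16)          # position dims (t, x, y) for 16 tokens
grid[:, 0] = torch.arange(16.0)       # toy temporal positions
cos, sin = precompute_freqs_cis(
    grid,
    dim=64,
    out_dtype=torch.float32,
    max_pos=[20, 2048, 2048],
    use_middle_indices_grid=False,    # plain grid, not (start, end) pairs
)
print(cos.shape, sin.shape)           # torch.Size([1, 16, 64]) each for the interleaved rope type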
packages/ltx-core/src/ltx_core/model/transformer/text_projection.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+
3
+
4
+ class PixArtAlphaTextProjection(torch.nn.Module):
5
+ """
6
+ Projects caption embeddings using dual linear layers.
7
+ Flow: linear_1 → activation → linear_2
8
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
9
+ """
10
+
11
+ def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"):
12
+ super().__init__()
13
+ if out_features is None:
14
+ out_features = hidden_size
15
+ self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
16
+ if act_fn == "gelu_tanh":
17
+ self.act_1 = torch.nn.GELU(approximate="tanh")
18
+ elif act_fn == "silu":
19
+ self.act_1 = torch.nn.SiLU()
20
+ else:
21
+ raise ValueError(f"Unknown activation function: {act_fn}")
22
+ self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
23
+
24
+ def forward(self, caption: torch.Tensor) -> torch.Tensor:
25
+ hidden_states = self.linear_1(caption)
26
+ hidden_states = self.act_1(hidden_states)
27
+ hidden_states = self.linear_2(hidden_states)
28
+ return hidden_states
29
+
30
+
31
+ def create_caption_projection(transformer_config: dict, audio: bool = False) -> PixArtAlphaTextProjection:
32
+ """Create a caption projection for the transformer (V1/19B only)."""
33
+ caption_channels = transformer_config["caption_channels"]
34
+ if audio:
35
+ inner_dim = transformer_config["audio_num_attention_heads"] * transformer_config["audio_attention_head_dim"]
36
+ else:
37
+ inner_dim = transformer_config["num_attention_heads"] * transformer_config["attention_head_dim"]
38
+ return PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
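Toy usage of the projection (dimensions are hypothetical; real values come from the transformer config via create_caption_projection above):

import torch

proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=2048)
caption = torch.randn(1, 128, 4096)   # (batch, text tokens, caption_channels)
out = proj(caption)                   # -> (1, 128, 2048)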
packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py ADDED
@@ -0,0 +1,143 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+
6
+ def get_timestep_embedding(
7
+ timesteps: torch.Tensor,
8
+ embedding_dim: int,
9
+ flip_sin_to_cos: bool = False,
10
+ downscale_freq_shift: float = 1,
11
+ scale: float = 1,
12
+ max_period: int = 10000,
13
+ ) -> torch.Tensor:
14
+ """
15
+ Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
16
+ Args:
17
+ timesteps (torch.Tensor):
18
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
19
+ embedding_dim (int):
20
+ the dimension of the output.
21
+ flip_sin_to_cos (bool):
22
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
23
+ downscale_freq_shift (float):
24
+ Controls the delta between frequencies between dimensions
25
+ scale (float):
26
+ Scaling factor applied to the embeddings.
27
+ max_period (int):
28
+ Controls the maximum frequency of the embeddings
29
+ Returns:
30
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
31
+ """
32
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
33
+
34
+ half_dim = embedding_dim // 2
35
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
36
+ exponent = exponent / (half_dim - downscale_freq_shift)
37
+
38
+ emb = torch.exp(exponent)
39
+ emb = timesteps[:, None].float() * emb[None, :]
40
+
41
+ # scale embeddings
42
+ emb = scale * emb
43
+
44
+ # concat sine and cosine embeddings
45
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
46
+
47
+ # flip sine and cosine embeddings
48
+ if flip_sin_to_cos:
49
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
50
+
51
+ # zero pad
52
+ if embedding_dim % 2 == 1:
53
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
54
+ return emb
55
+
56
+
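A quick usage example (the 256-channel, flip_sin_to_cos=True setting matches the Timesteps module below):

import torch

t = torch.tensor([0.0, 250.0, 999.0])   # one (possibly fractional) timestep per batch element
emb = get_timestep_embedding(t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)                         # torch.Size([3, 256])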
57
+ class TimestepEmbedding(torch.nn.Module):
58
+ def __init__(
59
+ self,
60
+ in_channels: int,
61
+ time_embed_dim: int,
62
+ out_dim: int | None = None,
63
+ post_act_fn: str | None = None,
64
+ cond_proj_dim: int | None = None,
65
+ sample_proj_bias: bool = True,
66
+ ):
67
+ super().__init__()
68
+
69
+ self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
70
+
71
+ if cond_proj_dim is not None:
72
+ self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False)
73
+ else:
74
+ self.cond_proj = None
75
+
76
+ self.act = torch.nn.SiLU()
77
+ time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
78
+
79
+ self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
80
+
81
+ if post_act_fn is None:
82
+ self.post_act = None
+ else:
+ # post_act_fn is accepted but not implemented; fail at construction time
+ # rather than with an AttributeError in forward().
+ raise NotImplementedError(f"post_act_fn '{post_act_fn}' is not supported")
83
+
84
+ def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor:
85
+ if condition is not None:
86
+ sample = sample + self.cond_proj(condition)
87
+ sample = self.linear_1(sample)
88
+
89
+ if self.act is not None:
90
+ sample = self.act(sample)
91
+
92
+ sample = self.linear_2(sample)
93
+
94
+ if self.post_act is not None:
95
+ sample = self.post_act(sample)
96
+ return sample
97
+
98
+
99
+ class Timesteps(torch.nn.Module):
100
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
101
+ super().__init__()
102
+ self.num_channels = num_channels
103
+ self.flip_sin_to_cos = flip_sin_to_cos
104
+ self.downscale_freq_shift = downscale_freq_shift
105
+ self.scale = scale
106
+
107
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
108
+ t_emb = get_timestep_embedding(
109
+ timesteps,
110
+ self.num_channels,
111
+ flip_sin_to_cos=self.flip_sin_to_cos,
112
+ downscale_freq_shift=self.downscale_freq_shift,
113
+ scale=self.scale,
114
+ )
115
+ return t_emb
116
+
117
+
118
+ class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):
119
+ """
120
+ For PixArt-Alpha.
121
+ Reference:
122
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ embedding_dim: int,
128
+ size_emb_dim: int,
129
+ ):
130
+ super().__init__()
131
+
132
+ self.outdim = size_emb_dim
133
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
134
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
135
+
136
+ def forward(
137
+ self,
138
+ timestep: torch.Tensor,
139
+ hidden_dtype: torch.dtype,
140
+ ) -> torch.Tensor:
141
+ timesteps_proj = self.time_proj(timestep)
142
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
143
+ return timesteps_emb
packages/ltx-core/src/ltx_core/model/transformer/transformer.py ADDED
@@ -0,0 +1,398 @@
1
+ from dataclasses import dataclass, replace
2
+
3
+ import torch
4
+
5
+ from ltx_core.guidance.perturbations import BatchedPerturbationConfig, PerturbationType
6
+ from ltx_core.model.transformer.adaln import adaln_embedding_coefficient
7
+ from ltx_core.model.transformer.attention import Attention, AttentionCallable, AttentionFunction
8
+ from ltx_core.model.transformer.feed_forward import FeedForward
9
+ from ltx_core.model.transformer.rope import LTXRopeType
10
+ from ltx_core.model.transformer.transformer_args import TransformerArgs
11
+ from ltx_core.utils import rms_norm
12
+
13
+
14
+ @dataclass
15
+ class TransformerConfig:
16
+ dim: int
17
+ heads: int
18
+ d_head: int
19
+ context_dim: int
20
+ apply_gated_attention: bool = False
21
+ cross_attention_adaln: bool = False
22
+
23
+
24
+ class BasicAVTransformerBlock(torch.nn.Module):
25
+ def __init__(
26
+ self,
27
+ idx: int,
28
+ video: TransformerConfig | None = None,
29
+ audio: TransformerConfig | None = None,
30
+ rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
31
+ norm_eps: float = 1e-6,
32
+ attention_function: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT,
33
+ ):
34
+ super().__init__()
35
+
36
+ self.idx = idx
37
+ if video is not None:
38
+ self.attn1 = Attention(
39
+ query_dim=video.dim,
40
+ heads=video.heads,
41
+ dim_head=video.d_head,
42
+ context_dim=None,
43
+ rope_type=rope_type,
44
+ norm_eps=norm_eps,
45
+ attention_function=attention_function,
46
+ apply_gated_attention=video.apply_gated_attention,
47
+ )
48
+ self.attn2 = Attention(
49
+ query_dim=video.dim,
50
+ context_dim=video.context_dim,
51
+ heads=video.heads,
52
+ dim_head=video.d_head,
53
+ rope_type=rope_type,
54
+ norm_eps=norm_eps,
55
+ attention_function=attention_function,
56
+ apply_gated_attention=video.apply_gated_attention,
57
+ )
58
+ self.ff = FeedForward(video.dim, dim_out=video.dim)
59
+ video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
60
+ self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))
61
+
62
+ if audio is not None:
63
+ self.audio_attn1 = Attention(
64
+ query_dim=audio.dim,
65
+ heads=audio.heads,
66
+ dim_head=audio.d_head,
67
+ context_dim=None,
68
+ rope_type=rope_type,
69
+ norm_eps=norm_eps,
70
+ attention_function=attention_function,
71
+ apply_gated_attention=audio.apply_gated_attention,
72
+ )
73
+ self.audio_attn2 = Attention(
74
+ query_dim=audio.dim,
75
+ context_dim=audio.context_dim,
76
+ heads=audio.heads,
77
+ dim_head=audio.d_head,
78
+ rope_type=rope_type,
79
+ norm_eps=norm_eps,
80
+ attention_function=attention_function,
81
+ apply_gated_attention=audio.apply_gated_attention,
82
+ )
83
+ self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
84
+ audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
85
+ self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))
86
+
87
+ if audio is not None and video is not None:
88
+ # Q: Video, K,V: Audio
89
+ self.audio_to_video_attn = Attention(
90
+ query_dim=video.dim,
91
+ context_dim=audio.dim,
92
+ heads=audio.heads,
93
+ dim_head=audio.d_head,
94
+ rope_type=rope_type,
95
+ norm_eps=norm_eps,
96
+ attention_function=attention_function,
97
+ apply_gated_attention=video.apply_gated_attention,
98
+ )
99
+
100
+ # Q: Audio, K,V: Video
101
+ self.video_to_audio_attn = Attention(
102
+ query_dim=audio.dim,
103
+ context_dim=video.dim,
104
+ heads=audio.heads,
105
+ dim_head=audio.d_head,
106
+ rope_type=rope_type,
107
+ norm_eps=norm_eps,
108
+ attention_function=attention_function,
109
+ apply_gated_attention=audio.apply_gated_attention,
110
+ )
111
+
112
+ self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
113
+ self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))
114
+
115
+ self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
116
+ audio is not None and audio.cross_attention_adaln
117
+ )
118
+
119
+ if self.cross_attention_adaln and video is not None:
120
+ self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
121
+ if self.cross_attention_adaln and audio is not None:
122
+ self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))
123
+
124
+ self.norm_eps = norm_eps
125
+
126
+ def get_ada_values(
127
+ self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice
128
+ ) -> tuple[torch.Tensor, ...]:
129
+ num_ada_params = scale_shift_table.shape[0]
130
+
131
+ ada_values = (
132
+ scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
133
+ + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
134
+ ).unbind(dim=2)
135
+ return ada_values
136
+
137
+ def get_av_ca_ada_values(
138
+ self,
139
+ scale_shift_table: torch.Tensor,
140
+ batch_size: int,
141
+ scale_shift_timestep: torch.Tensor,
142
+ gate_timestep: torch.Tensor,
143
+ scale_shift_indices: slice,
144
+ num_scale_shift_values: int = 4,
145
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ scale_shift_ada_values = self.get_ada_values(
147
+ scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
148
+ )
149
+ gate_ada_values = self.get_ada_values(
150
+ scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
151
+ )
152
+
153
+ scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
154
+ (gate,) = (t.squeeze(2) for t in gate_ada_values)
155
+
156
+ return scale, shift, gate
157
+
158
+ def _apply_text_cross_attention(
159
+ self,
160
+ x: torch.Tensor,
161
+ context: torch.Tensor,
162
+ attn: AttentionCallable,
163
+ scale_shift_table: torch.Tensor,
164
+ prompt_scale_shift_table: torch.Tensor | None,
165
+ timestep: torch.Tensor,
166
+ prompt_timestep: torch.Tensor | None,
167
+ context_mask: torch.Tensor | None,
168
+ cross_attention_adaln: bool = False,
169
+ ) -> torch.Tensor:
170
+ """Apply text cross-attention, with optional AdaLN modulation."""
171
+ if cross_attention_adaln:
172
+ shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))
173
+ return apply_cross_attention_adaln(
174
+ x,
175
+ context,
176
+ attn,
177
+ shift_q,
178
+ scale_q,
179
+ gate,
180
+ prompt_scale_shift_table,
181
+ prompt_timestep,
182
+ context_mask,
183
+ self.norm_eps,
184
+ )
185
+ return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)
186
+
187
+ def forward( # noqa: PLR0915
188
+ self,
189
+ video: TransformerArgs | None,
190
+ audio: TransformerArgs | None,
191
+ perturbations: BatchedPerturbationConfig | None = None,
192
+ ) -> tuple[TransformerArgs | None, TransformerArgs | None]:
193
+ if video is None and audio is None:
194
+ raise ValueError("At least one of video or audio must be provided")
195
+
196
+ batch_size = (video or audio).x.shape[0]
197
+
198
+ if perturbations is None:
199
+ perturbations = BatchedPerturbationConfig.empty(batch_size)
200
+
201
+ vx = video.x if video is not None else None
202
+ ax = audio.x if audio is not None else None
203
+
204
+ run_vx = video is not None and video.enabled and vx.numel() > 0
205
+ run_ax = audio is not None and audio.enabled and ax.numel() > 0
206
+
207
+ run_a2v = run_vx and (audio is not None and ax.numel() > 0)
208
+ run_v2a = run_ax and (video is not None and vx.numel() > 0)
209
+
210
+ if run_vx:
211
+ vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
212
+ self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
213
+ )
214
+ norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
215
+ del vshift_msa, vscale_msa
216
+
217
+ all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
218
+ none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
219
+ v_mask = (
220
+ perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
221
+ if not all_perturbed and not none_perturbed
222
+ else None
223
+ )
224
+ vx = (
225
+ vx
226
+ + self.attn1(
227
+ norm_vx,
228
+ pe=video.positional_embeddings,
229
+ mask=video.self_attention_mask,
230
+ perturbation_mask=v_mask,
231
+ all_perturbed=all_perturbed,
232
+ )
233
+ * vgate_msa
234
+ )
235
+ del vgate_msa, norm_vx, v_mask
236
+ vx = vx + self._apply_text_cross_attention(
237
+ vx,
238
+ video.context,
239
+ self.attn2,
240
+ self.scale_shift_table,
241
+ getattr(self, "prompt_scale_shift_table", None),
242
+ video.timesteps,
243
+ video.prompt_timestep,
244
+ video.context_mask,
245
+ cross_attention_adaln=self.cross_attention_adaln,
246
+ )
247
+
248
+ if run_ax:
249
+ ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
250
+ self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
251
+ )
252
+
253
+ norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
254
+ del ashift_msa, ascale_msa
255
+ all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
256
+ none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
257
+ a_mask = (
258
+ perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
259
+ if not all_perturbed and not none_perturbed
260
+ else None
261
+ )
262
+ ax = (
263
+ ax
264
+ + self.audio_attn1(
265
+ norm_ax,
266
+ pe=audio.positional_embeddings,
267
+ mask=audio.self_attention_mask,
268
+ perturbation_mask=a_mask,
269
+ all_perturbed=all_perturbed,
270
+ )
271
+ * agate_msa
272
+ )
273
+ del agate_msa, norm_ax, a_mask
274
+ ax = ax + self._apply_text_cross_attention(
275
+ ax,
276
+ audio.context,
277
+ self.audio_attn2,
278
+ self.audio_scale_shift_table,
279
+ getattr(self, "audio_prompt_scale_shift_table", None),
280
+ audio.timesteps,
281
+ audio.prompt_timestep,
282
+ audio.context_mask,
283
+ cross_attention_adaln=self.cross_attention_adaln,
284
+ )
285
+
286
+ # Audio-video cross-attention.
287
+ if run_a2v or run_v2a:
288
+ vx_norm3 = rms_norm(vx, eps=self.norm_eps)
289
+ ax_norm3 = rms_norm(ax, eps=self.norm_eps)
290
+
291
+ if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
292
+ scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
293
+ self.scale_shift_table_a2v_ca_video,
294
+ vx.shape[0],
295
+ video.cross_scale_shift_timestep,
296
+ video.cross_gate_timestep,
297
+ slice(0, 2),
298
+ )
299
+ vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
300
+ del scale_ca_video_a2v, shift_ca_video_a2v
301
+
302
+ scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
303
+ self.scale_shift_table_a2v_ca_audio,
304
+ ax.shape[0],
305
+ audio.cross_scale_shift_timestep,
306
+ audio.cross_gate_timestep,
307
+ slice(0, 2),
308
+ )
309
+ ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
310
+ del scale_ca_audio_a2v, shift_ca_audio_a2v
311
+ a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
312
+ vx = vx + (
313
+ self.audio_to_video_attn(
314
+ vx_scaled,
315
+ context=ax_scaled,
316
+ pe=video.cross_positional_embeddings,
317
+ k_pe=audio.cross_positional_embeddings,
318
+ )
319
+ * gate_out_a2v
320
+ * a2v_mask
321
+ )
322
+ del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled
323
+
324
+ if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
325
+ scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
326
+ self.scale_shift_table_a2v_ca_audio,
327
+ ax.shape[0],
328
+ audio.cross_scale_shift_timestep,
329
+ audio.cross_gate_timestep,
330
+ slice(2, 4),
331
+ )
332
+ ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
333
+ del scale_ca_audio_v2a, shift_ca_audio_v2a
334
+ scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
335
+ self.scale_shift_table_a2v_ca_video,
336
+ vx.shape[0],
337
+ video.cross_scale_shift_timestep,
338
+ video.cross_gate_timestep,
339
+ slice(2, 4),
340
+ )
341
+ vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
342
+ del scale_ca_video_v2a, shift_ca_video_v2a
343
+ v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
344
+ ax = ax + (
345
+ self.video_to_audio_attn(
346
+ ax_scaled,
347
+ context=vx_scaled,
348
+ pe=audio.cross_positional_embeddings,
349
+ k_pe=video.cross_positional_embeddings,
350
+ )
351
+ * gate_out_v2a
352
+ * v2a_mask
353
+ )
354
+ del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled
355
+
356
+ del vx_norm3, ax_norm3
357
+
358
+ if run_vx:
359
+ vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
360
+ self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
361
+ )
362
+ vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
363
+ vx = vx + self.ff(vx_scaled) * vgate_mlp
364
+
365
+ del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled
366
+
367
+ if run_ax:
368
+ ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
369
+ self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
370
+ )
371
+ ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
372
+ ax = ax + self.audio_ff(ax_scaled) * agate_mlp
373
+
374
+ del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled
375
+
376
+ return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
377
+
378
+
379
+ def apply_cross_attention_adaln(
380
+ x: torch.Tensor,
381
+ context: torch.Tensor,
382
+ attn: AttentionCallable,
383
+ q_shift: torch.Tensor,
384
+ q_scale: torch.Tensor,
385
+ q_gate: torch.Tensor,
386
+ prompt_scale_shift_table: torch.Tensor,
387
+ prompt_timestep: torch.Tensor,
388
+ context_mask: torch.Tensor | None = None,
389
+ norm_eps: float = 1e-6,
390
+ ) -> torch.Tensor:
391
+ batch_size = x.shape[0]
392
+ shift_kv, scale_kv = (
393
+ prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
394
+ + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
395
+ ).unbind(dim=2)
396
+ attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift
397
+ encoder_hidden_states = context * (1 + scale_kv) + shift_kv
398
+ return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate
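Every residual branch in the block above follows the same shift/scale/gate recipe; schematically (a sketch, using F.rms_norm from PyTorch >= 2.4 in place of the repo's ltx_core.utils.rms_norm):

import torch

def adaln_branch(x, sublayer, shift, scale, gate, eps=1e-6):
    # normalize -> modulate -> sublayer -> gate -> residual, as in attn1/attn2/ff above
    h = torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
    return x + sublayer(h * (1 + scale) + shift) * gate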
packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py ADDED
@@ -0,0 +1,297 @@
1
+ from dataclasses import dataclass, replace
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.transformer.adaln import AdaLayerNormSingle
6
+ from ltx_core.model.transformer.modality import Modality
7
+ from ltx_core.model.transformer.rope import (
8
+ LTXRopeType,
9
+ generate_freq_grid_np,
10
+ generate_freq_grid_pytorch,
11
+ precompute_freqs_cis,
12
+ )
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class TransformerArgs:
17
+ x: torch.Tensor
18
+ context: torch.Tensor
19
+ context_mask: torch.Tensor
20
+ timesteps: torch.Tensor
21
+ embedded_timestep: torch.Tensor
22
+ positional_embeddings: torch.Tensor
23
+ cross_positional_embeddings: torch.Tensor | None
24
+ cross_scale_shift_timestep: torch.Tensor | None
25
+ cross_gate_timestep: torch.Tensor | None
26
+ enabled: bool
27
+ prompt_timestep: torch.Tensor | None = None
28
+ self_attention_mask: torch.Tensor | None = (
29
+ None # Additive log-space self-attention bias (B, 1, T, T), None = full attention
30
+ )
31
+
32
+
33
+ class TransformerArgsPreprocessor:
34
+ def __init__( # noqa: PLR0913
35
+ self,
36
+ patchify_proj: torch.nn.Linear,
37
+ adaln: AdaLayerNormSingle,
38
+ inner_dim: int,
39
+ max_pos: list[int],
40
+ num_attention_heads: int,
41
+ use_middle_indices_grid: bool,
42
+ timestep_scale_multiplier: int,
43
+ double_precision_rope: bool,
44
+ positional_embedding_theta: float,
45
+ rope_type: LTXRopeType,
46
+ caption_projection: torch.nn.Module | None = None,
47
+ prompt_adaln: AdaLayerNormSingle | None = None,
48
+ ) -> None:
49
+ self.patchify_proj = patchify_proj
50
+ self.adaln = adaln
51
+ self.inner_dim = inner_dim
52
+ self.max_pos = max_pos
53
+ self.num_attention_heads = num_attention_heads
54
+ self.use_middle_indices_grid = use_middle_indices_grid
55
+ self.timestep_scale_multiplier = timestep_scale_multiplier
56
+ self.double_precision_rope = double_precision_rope
57
+ self.positional_embedding_theta = positional_embedding_theta
58
+ self.rope_type = rope_type
59
+ self.caption_projection = caption_projection
60
+ self.prompt_adaln = prompt_adaln
61
+
62
+ def _prepare_timestep(
63
+ self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
64
+ ) -> tuple[torch.Tensor, torch.Tensor]:
65
+ """Prepare timestep embeddings."""
66
+ timestep_scaled = timestep * self.timestep_scale_multiplier
67
+ timestep, embedded_timestep = adaln(
68
+ timestep_scaled.flatten(),
69
+ hidden_dtype=hidden_dtype,
70
+ )
71
+ # Second dimension is 1 or number of tokens (if timestep_per_token)
72
+ timestep = timestep.view(batch_size, -1, timestep.shape[-1])
73
+ embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
74
+
75
+ return timestep, embedded_timestep
76
+
77
+ def _prepare_context(
78
+ self,
79
+ context: torch.Tensor,
80
+ x: torch.Tensor,
81
+ ) -> torch.Tensor:
82
+ """Prepare context for transformer blocks."""
83
+ if self.caption_projection is not None:
84
+ context = self.caption_projection(context)
85
+ batch_size = x.shape[0]
86
+ return context.view(batch_size, -1, x.shape[-1])
87
+
88
+ def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
89
+ """Prepare attention mask."""
90
+ if attention_mask is None or torch.is_floating_point(attention_mask):
91
+ return attention_mask
92
+
93
+ return (attention_mask - 1).to(x_dtype).reshape(
94
+ (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
95
+ ) * torch.finfo(x_dtype).max
96
+
97
+ def _prepare_self_attention_mask(
98
+ self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
99
+ ) -> torch.Tensor | None:
100
+ """Prepare self-attention mask by converting [0,1] values to additive log-space bias.
101
+ Input shape: (B, T, T) with values in [0, 1].
102
+ Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
103
+ for masked positions.
104
+ Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
105
+ representable value). Strictly positive entries are converted via log-space for
106
+ smooth attenuation, with small values clamped for numerical stability.
107
+ Returns None if input is None (no masking).
108
+ """
109
+ if attention_mask is None:
110
+ return None
111
+
112
+ # Convert [0, 1] attention mask to additive log-space bias:
113
+ # 1.0 -> log(1.0) = 0.0 (no bias, full attention)
114
+ # 0.0 -> finfo.min (fully masked)
115
+ finfo = torch.finfo(x_dtype)
116
+ eps = finfo.tiny
117
+
118
+ bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
119
+ positive = attention_mask > 0
120
+ if positive.any():
121
+ bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)
122
+
123
+ return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast
124
+
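A standalone toy check of the mapping described in the docstring above:

import torch

mask = torch.tensor([[[1.0, 0.5, 0.0]]])   # (B, T, T) soft attention strengths
finfo = torch.finfo(torch.float32)
bias = torch.full_like(mask, finfo.min)
keep = mask > 0
bias[keep] = mask[keep].clamp(min=finfo.tiny).log()
# bias -> [0.0, -0.6931, -3.4e38]: full attention, attenuated, fully masked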
125
+ def _prepare_positional_embeddings(
126
+ self,
127
+ positions: torch.Tensor,
128
+ inner_dim: int,
129
+ max_pos: list[int],
130
+ use_middle_indices_grid: bool,
131
+ num_attention_heads: int,
132
+ x_dtype: torch.dtype,
133
+ ) -> torch.Tensor:
134
+ """Prepare positional embeddings."""
135
+ freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
136
+ pe = precompute_freqs_cis(
137
+ positions,
138
+ dim=inner_dim,
139
+ out_dtype=x_dtype,
140
+ theta=self.positional_embedding_theta,
141
+ max_pos=max_pos,
142
+ use_middle_indices_grid=use_middle_indices_grid,
143
+ num_attention_heads=num_attention_heads,
144
+ rope_type=self.rope_type,
145
+ freq_grid_generator=freq_grid_generator,
146
+ )
147
+ return pe
148
+
149
+ def prepare(
150
+ self,
151
+ modality: Modality,
152
+ cross_modality: Modality | None = None, # noqa: ARG002
153
+ ) -> TransformerArgs:
154
+ x = self.patchify_proj(modality.latent)
155
+ batch_size = x.shape[0]
156
+ timestep, embedded_timestep = self._prepare_timestep(
157
+ modality.timesteps, self.adaln, batch_size, modality.latent.dtype
158
+ )
159
+ prompt_timestep = None
160
+ if self.prompt_adaln is not None:
161
+ prompt_timestep, _ = self._prepare_timestep(
162
+ modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
163
+ )
164
+ context = self._prepare_context(modality.context, x)
165
+ attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
166
+ pe = self._prepare_positional_embeddings(
167
+ positions=modality.positions,
168
+ inner_dim=self.inner_dim,
169
+ max_pos=self.max_pos,
170
+ use_middle_indices_grid=self.use_middle_indices_grid,
171
+ num_attention_heads=self.num_attention_heads,
172
+ x_dtype=modality.latent.dtype,
173
+ )
174
+ self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
175
+ return TransformerArgs(
176
+ x=x,
177
+ context=context,
178
+ context_mask=attention_mask,
179
+ timesteps=timestep,
180
+ embedded_timestep=embedded_timestep,
181
+ positional_embeddings=pe,
182
+ cross_positional_embeddings=None,
183
+ cross_scale_shift_timestep=None,
184
+ cross_gate_timestep=None,
185
+ enabled=modality.enabled,
186
+ prompt_timestep=prompt_timestep,
187
+ self_attention_mask=self_attention_mask,
188
+ )
189
+
190
+
191
+ class MultiModalTransformerArgsPreprocessor:
192
+ def __init__( # noqa: PLR0913
193
+ self,
194
+ patchify_proj: torch.nn.Linear,
195
+ adaln: AdaLayerNormSingle,
196
+ cross_scale_shift_adaln: AdaLayerNormSingle,
197
+ cross_gate_adaln: AdaLayerNormSingle,
198
+ inner_dim: int,
199
+ max_pos: list[int],
200
+ num_attention_heads: int,
201
+ cross_pe_max_pos: int,
202
+ use_middle_indices_grid: bool,
203
+ audio_cross_attention_dim: int,
204
+ timestep_scale_multiplier: int,
205
+ double_precision_rope: bool,
206
+ positional_embedding_theta: float,
207
+ rope_type: LTXRopeType,
208
+ av_ca_timestep_scale_multiplier: int,
209
+ caption_projection: torch.nn.Module | None = None,
210
+ prompt_adaln: AdaLayerNormSingle | None = None,
211
+ ) -> None:
212
+ self.simple_preprocessor = TransformerArgsPreprocessor(
213
+ patchify_proj=patchify_proj,
214
+ adaln=adaln,
215
+ inner_dim=inner_dim,
216
+ max_pos=max_pos,
217
+ num_attention_heads=num_attention_heads,
218
+ use_middle_indices_grid=use_middle_indices_grid,
219
+ timestep_scale_multiplier=timestep_scale_multiplier,
220
+ double_precision_rope=double_precision_rope,
221
+ positional_embedding_theta=positional_embedding_theta,
222
+ rope_type=rope_type,
223
+ caption_projection=caption_projection,
224
+ prompt_adaln=prompt_adaln,
225
+ )
226
+ self.cross_scale_shift_adaln = cross_scale_shift_adaln
227
+ self.cross_gate_adaln = cross_gate_adaln
228
+ self.cross_pe_max_pos = cross_pe_max_pos
229
+ self.audio_cross_attention_dim = audio_cross_attention_dim
230
+ self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
231
+
232
+ def prepare(
233
+ self,
234
+ modality: Modality,
235
+ cross_modality: Modality | None = None,
236
+ ) -> TransformerArgs:
237
+ transformer_args = self.simple_preprocessor.prepare(modality)
238
+ if cross_modality is None:
239
+ return transformer_args
240
+
241
+ if cross_modality.sigma.numel() > 1:
242
+ if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
243
+ raise ValueError("Cross modality sigma must have the same batch size as the modality")
244
+ if cross_modality.sigma.ndim != 1:
245
+ raise ValueError("Cross modality sigma must be a 1D tensor")
246
+
247
+ cross_timestep = cross_modality.sigma.view(
248
+ modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
249
+ )
250
+
251
+ cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
252
+ positions=modality.positions[:, 0:1, :],
253
+ inner_dim=self.audio_cross_attention_dim,
254
+ max_pos=[self.cross_pe_max_pos],
255
+ use_middle_indices_grid=True,
256
+ num_attention_heads=self.simple_preprocessor.num_attention_heads,
257
+ x_dtype=modality.latent.dtype,
258
+ )
259
+
260
+ cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
261
+ timestep=cross_timestep,
262
+ timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
263
+ batch_size=transformer_args.x.shape[0],
264
+ hidden_dtype=modality.latent.dtype,
265
+ )
266
+
267
+ return replace(
268
+ transformer_args,
269
+ cross_positional_embeddings=cross_pe,
270
+ cross_scale_shift_timestep=cross_scale_shift_timestep,
271
+ cross_gate_timestep=cross_gate_timestep,
272
+ )
273
+
274
+ def _prepare_cross_attention_timestep(
275
+ self,
276
+ timestep: torch.Tensor | None,
277
+ timestep_scale_multiplier: int,
278
+ batch_size: int,
279
+ hidden_dtype: torch.dtype,
280
+ ) -> tuple[torch.Tensor, torch.Tensor]:
281
+ """Prepare cross attention timestep embeddings."""
282
+ timestep = timestep * timestep_scale_multiplier
283
+
284
+ av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
285
+
286
+ scale_shift_timestep, _ = self.cross_scale_shift_adaln(
287
+ timestep.flatten(),
288
+ hidden_dtype=hidden_dtype,
289
+ )
290
+ scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1])
291
+ gate_noise_timestep, _ = self.cross_gate_adaln(
292
+ timestep.flatten() * av_ca_factor,
293
+ hidden_dtype=hidden_dtype,
294
+ )
295
+ gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1])
296
+
297
+ return scale_shift_timestep, gate_noise_timestep
packages/ltx-core/src/ltx_core/model/video_vae/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ """Video VAE package."""
2
+
3
+ from ltx_core.model.video_vae.model_configurator import (
4
+ VAE_DECODER_COMFY_KEYS_FILTER,
5
+ VAE_ENCODER_COMFY_KEYS_FILTER,
6
+ VideoDecoderConfigurator,
7
+ VideoEncoderConfigurator,
8
+ )
9
+ from ltx_core.model.video_vae.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig
10
+ from ltx_core.model.video_vae.video_vae import VideoDecoder, VideoEncoder, decode_video, get_video_chunks_number
11
+
12
+ __all__ = [
13
+ "VAE_DECODER_COMFY_KEYS_FILTER",
14
+ "VAE_ENCODER_COMFY_KEYS_FILTER",
15
+ "SpatialTilingConfig",
16
+ "TemporalTilingConfig",
17
+ "TilingConfig",
18
+ "VideoDecoder",
19
+ "VideoDecoderConfigurator",
20
+ "VideoEncoder",
21
+ "VideoEncoderConfigurator",
22
+ "decode_video",
23
+ "get_video_chunks_number",
24
+ ]
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (811 Bytes).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/convolution.cpython-312.pyc ADDED
Binary file (12.4 kB).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/enums.cpython-312.pyc ADDED
Binary file (1.01 kB).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/model_configurator.cpython-312.pyc ADDED
Binary file (4.24 kB).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/ops.cpython-312.pyc ADDED
Binary file (5.01 kB).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/resnet.cpython-312.pyc ADDED
Binary file (11 kB).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/sampling.cpython-312.pyc ADDED
Binary file (4.96 kB).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/tiling.cpython-312.pyc ADDED
Binary file (13.9 kB).
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/video_vae.cpython-312.pyc ADDED
Binary file (44.5 kB).
packages/ltx-core/src/ltx_core/model/video_vae/convolution.py ADDED
@@ -0,0 +1,317 @@
1
+ import math
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from ltx_core.model.video_vae.enums import PaddingModeType
9
+
10
+
11
+ def make_conv_nd( # noqa: PLR0913
12
+ dims: Union[int, Tuple[int, int]],
13
+ in_channels: int,
14
+ out_channels: int,
15
+ kernel_size: int,
16
+ stride: int = 1,
17
+ padding: int = 0,
18
+ dilation: int = 1,
19
+ groups: int = 1,
20
+ bias: bool = True,
21
+ causal: bool = False,
22
+ spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
23
+ temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
24
+ ) -> nn.Module:
25
+ if not (spatial_padding_mode == temporal_padding_mode or causal):
26
+ raise NotImplementedError("spatial and temporal padding modes must be equal")
27
+ if dims == 2:
28
+ return nn.Conv2d(
29
+ in_channels=in_channels,
30
+ out_channels=out_channels,
31
+ kernel_size=kernel_size,
32
+ stride=stride,
33
+ padding=padding,
34
+ dilation=dilation,
35
+ groups=groups,
36
+ bias=bias,
37
+ padding_mode=spatial_padding_mode.value,
38
+ )
39
+ elif dims == 3:
40
+ if causal:
41
+ return CausalConv3d(
42
+ in_channels=in_channels,
43
+ out_channels=out_channels,
44
+ kernel_size=kernel_size,
45
+ stride=stride,
46
+ dilation=dilation,
47
+ groups=groups,
48
+ bias=bias,
49
+ spatial_padding_mode=spatial_padding_mode,
50
+ )
51
+ return nn.Conv3d(
52
+ in_channels=in_channels,
53
+ out_channels=out_channels,
54
+ kernel_size=kernel_size,
55
+ stride=stride,
56
+ padding=padding,
57
+ dilation=dilation,
58
+ groups=groups,
59
+ bias=bias,
60
+ padding_mode=spatial_padding_mode.value,
61
+ )
62
+ elif dims == (2, 1):
63
+ return DualConv3d(
64
+ in_channels=in_channels,
65
+ out_channels=out_channels,
66
+ kernel_size=kernel_size,
67
+ stride=stride,
68
+ padding=padding,
69
+ bias=bias,
70
+ padding_mode=spatial_padding_mode.value,
71
+ )
72
+ else:
73
+ raise ValueError(f"unsupported dimensions: {dims}")
74
+
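Usage sketch for make_conv_nd: dims selects the variant, e.g. dims=3 with causal=False gives a plain nn.Conv3d (toy shapes):

import torch

conv = make_conv_nd(dims=3, in_channels=8, out_channels=16, kernel_size=3, padding=1)
x = torch.randn(1, 8, 4, 32, 32)   # (batch, channels, frames, height, width)
y = conv(x)                        # -> (1, 16, 4, 32, 32)
# dims=2 -> nn.Conv2d; dims=3, causal=True -> CausalConv3d; dims=(2, 1) -> DualConv3d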
75
+
76
+ def make_linear_nd(
77
+ dims: int,
78
+ in_channels: int,
79
+ out_channels: int,
80
+ bias: bool = True,
81
+ ) -> nn.Module:
82
+ if dims == 2:
83
+ return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
84
+ elif dims in (3, (2, 1)):
85
+ return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
86
+ else:
87
+ raise ValueError(f"unsupported dimensions: {dims}")
88
+
89
+
90
+ class DualConv3d(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_channels: int,
94
+ out_channels: int,
95
+ kernel_size: int,
96
+ stride: Union[int, Tuple[int, int, int]] = 1,
97
+ padding: Union[int, Tuple[int, int, int]] = 0,
98
+ dilation: Union[int, Tuple[int, int, int]] = 1,
99
+ groups: int = 1,
100
+ bias: bool = True,
101
+ padding_mode: str = "zeros",
102
+ ) -> None:
103
+ super(DualConv3d, self).__init__()
104
+
105
+ self.in_channels = in_channels
106
+ self.out_channels = out_channels
107
+ self.padding_mode = padding_mode
108
+ # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
109
+ if isinstance(kernel_size, int):
110
+ kernel_size = (kernel_size, kernel_size, kernel_size)
111
+ if kernel_size == (1, 1, 1):
112
+ raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.")
113
+ if isinstance(stride, int):
114
+ stride = (stride, stride, stride)
115
+ if isinstance(padding, int):
116
+ padding = (padding, padding, padding)
117
+ if isinstance(dilation, int):
118
+ dilation = (dilation, dilation, dilation)
119
+
120
+ # Set parameters for convolutions
121
+ self.groups = groups
122
+ self.bias = bias
123
+
124
+ # Define the size of the channels after the first convolution
125
+ intermediate_channels = out_channels if in_channels < out_channels else in_channels
126
+
127
+ # Define parameters for the first convolution
128
+ self.weight1 = nn.Parameter(
129
+ torch.Tensor(
130
+ intermediate_channels,
131
+ in_channels // groups,
132
+ 1,
133
+ kernel_size[1],
134
+ kernel_size[2],
135
+ )
136
+ )
137
+ self.stride1 = (1, stride[1], stride[2])
138
+ self.padding1 = (0, padding[1], padding[2])
139
+ self.dilation1 = (1, dilation[1], dilation[2])
140
+ if bias:
141
+ self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
142
+ else:
143
+ self.register_parameter("bias1", None)
144
+
145
+ # Define parameters for the second convolution
146
+ self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))
147
+ self.stride2 = (stride[0], 1, 1)
148
+ self.padding2 = (padding[0], 0, 0)
149
+ self.dilation2 = (dilation[0], 1, 1)
150
+ if bias:
151
+ self.bias2 = nn.Parameter(torch.Tensor(out_channels))
152
+ else:
153
+ self.register_parameter("bias2", None)
154
+
155
+ # Initialize weights and biases
156
+ self.reset_parameters()
157
+
158
+ def reset_parameters(self) -> None:
159
+ nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5))
160
+ nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5))
161
+ if self.bias:
162
+ fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
163
+ bound1 = 1 / torch.sqrt(fan_in1)
164
+ nn.init.uniform_(self.bias1, -bound1, bound1)
165
+ fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
166
+ bound2 = 1 / torch.sqrt(fan_in2)
167
+ nn.init.uniform_(self.bias2, -bound2, bound2)
168
+
169
+ def forward(
170
+ self,
171
+ x: torch.Tensor,
172
+ use_conv3d: bool = False,
173
+ skip_time_conv: bool = False,
174
+ ) -> torch.Tensor:
175
+ if use_conv3d:
176
+ return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
177
+ else:
178
+ return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
179
+
180
+ def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
181
+ # First convolution
182
+ x = F.conv3d(
183
+ x,
184
+ self.weight1,
185
+ self.bias1,
186
+ self.stride1,
187
+ self.padding1,
188
+ self.dilation1,
189
+ self.groups,
190
+ padding_mode=self.padding_mode,
191
+ )
192
+
193
+ if skip_time_conv:
194
+ return x
195
+
196
+ # Second convolution
197
+ x = F.conv3d(
198
+ x,
199
+ self.weight2,
200
+ self.bias2,
201
+ self.stride2,
202
+ self.padding2,
203
+ self.dilation2,
204
+ self.groups,
205
+ padding_mode=self.padding_mode,
206
+ )
207
+
208
+ return x
209
+
210
+ def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
211
+ b, _, _, h, w = x.shape
212
+
213
+ # First 2D convolution
214
+ x = rearrange(x, "b c d h w -> (b d) c h w")
215
+ # Squeeze the depth dimension out of weight1 since it's 1
216
+ weight1 = self.weight1.squeeze(2)
217
+ # Select stride, padding, and dilation for the 2D convolution
218
+ stride1 = (self.stride1[1], self.stride1[2])
219
+ padding1 = (self.padding1[1], self.padding1[2])
220
+ dilation1 = (self.dilation1[1], self.dilation1[2])
221
+ x = F.conv2d(
222
+ x,
223
+ weight1,
224
+ self.bias1,
225
+ stride1,
226
+ padding1,
227
+ dilation1,
228
+ self.groups,
229
+ padding_mode=self.padding_mode,
230
+ )
231
+
232
+ _, _, h, w = x.shape
233
+
234
+ if skip_time_conv:
235
+ x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
236
+ return x
237
+
238
+ # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
239
+ x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
240
+
241
+ # Reshape weight2 to match the expected dimensions for conv1d
242
+ weight2 = self.weight2.squeeze(-1).squeeze(-1)
243
+ # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
244
+ stride2 = self.stride2[0]
245
+ padding2 = self.padding2[0]
246
+ dilation2 = self.dilation2[0]
247
+ x = F.conv1d(
248
+ x,
249
+ weight2,
250
+ self.bias2,
251
+ stride2,
252
+ padding2,
253
+ dilation2,
254
+ self.groups,
255
+ padding_mode=self.padding_mode,
256
+ )
257
+ x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
258
+
259
+ return x
260
+
261
+ @property
262
+ def weight(self) -> torch.Tensor:
263
+ return self.weight2
264
+
265
+
266
+ class CausalConv3d(nn.Module):
267
+ def __init__(
268
+ self,
269
+ in_channels: int,
270
+ out_channels: int,
271
+ kernel_size: int = 3,
272
+ stride: Union[int, Tuple[int]] = 1,
273
+ dilation: int = 1,
274
+ groups: int = 1,
275
+ bias: bool = True,
276
+ spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
277
+ ) -> None:
278
+ super().__init__()
279
+
280
+ self.in_channels = in_channels
281
+ self.out_channels = out_channels
282
+
283
+ kernel_size = (kernel_size, kernel_size, kernel_size)
284
+ self.time_kernel_size = kernel_size[0]
285
+
286
+ dilation = (dilation, 1, 1)
287
+
288
+ height_pad = kernel_size[1] // 2
289
+ width_pad = kernel_size[2] // 2
290
+ padding = (0, height_pad, width_pad)
291
+
292
+ self.conv = nn.Conv3d(
293
+ in_channels,
294
+ out_channels,
295
+ kernel_size,
296
+ stride=stride,
297
+ dilation=dilation,
298
+ padding=padding,
299
+ padding_mode=spatial_padding_mode.value,
300
+ groups=groups,
301
+ bias=bias,
302
+ )
303
+
304
+ def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
305
+ if causal:
306
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))
307
+ x = torch.concatenate((first_frame_pad, x), dim=2)
308
+ else:
309
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
310
+ last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
311
+ x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
312
+ x = self.conv(x)
313
+ return x
314
+
315
+ @property
316
+ def weight(self) -> torch.Tensor:
317
+ return self.conv.weight
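A minimal usage sketch for the factory above, assuming the `ltx-core` package is installed so that `ltx_core.model.video_vae.convolution` is importable: with `dims=3` and `causal=True`, `make_conv_nd` returns a `CausalConv3d` whose first-frame temporal padding preserves the frame count.

import torch

from ltx_core.model.video_vae.convolution import make_conv_nd

conv = make_conv_nd(dims=3, in_channels=3, out_channels=16, kernel_size=3, causal=True)
x = torch.randn(1, 3, 9, 64, 64)  # (B, C, F, H, W)
y = conv(x, causal=True)          # left-pads time with 2 copies of the first frame
print(y.shape)                    # torch.Size([1, 16, 9, 64, 64])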
packages/ltx-core/src/ltx_core/model/video_vae/model_configurator.py ADDED
@@ -0,0 +1,79 @@
from ltx_core.loader.sd_ops import SDOps
from ltx_core.model.model_protocol import ModelConfigurator
from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType
from ltx_core.model.video_vae.video_vae import VideoDecoder, VideoEncoder


class VideoEncoderConfigurator(ModelConfigurator[VideoEncoder]):
    """Configurator for creating a video VAE Encoder from a configuration dictionary."""

    @classmethod
    def from_config(cls: type[VideoEncoder], config: dict) -> VideoEncoder:
        config = config.get("vae", {})
        convolution_dimensions = config.get("dims", 3)
        in_channels = config.get("in_channels", 3)
        latent_channels = config.get("latent_channels", 128)
        spatial_padding_mode = PaddingModeType(config.get("spatial_padding_mode", "zeros"))
        encoder_blocks = config.get("encoder_blocks", [])
        patch_size = config.get("patch_size", 4)
        norm_layer_str = config.get("norm_layer", "pixel_norm")
        latent_log_var_str = config.get("latent_log_var", "uniform")

        return VideoEncoder(
            convolution_dimensions=convolution_dimensions,
            in_channels=in_channels,
            out_channels=latent_channels,
            encoder_blocks=encoder_blocks,
            patch_size=patch_size,
            norm_layer=NormLayerType(norm_layer_str),
            latent_log_var=LogVarianceType(latent_log_var_str),
            encoder_spatial_padding_mode=spatial_padding_mode,
        )


class VideoDecoderConfigurator(ModelConfigurator[VideoDecoder]):
    """Configurator for creating a video VAE Decoder from a configuration dictionary."""

    @classmethod
    def from_config(cls: type[VideoDecoder], config: dict) -> VideoDecoder:
        config = config.get("vae", {})
        convolution_dimensions = config.get("dims", 3)
        latent_channels = config.get("latent_channels", 128)
        spatial_padding_mode = PaddingModeType(config.get("spatial_padding_mode", "reflect"))
        out_channels = config.get("out_channels", 3)
        decoder_blocks = config.get("decoder_blocks", [])
        patch_size = config.get("patch_size", 4)
        norm_layer_str = config.get("norm_layer", "pixel_norm")
        causal = config.get("causal_decoder", False)
        timestep_conditioning = config.get("timestep_conditioning", True)
        base_channels = config.get("decoder_base_channels", 128)

        return VideoDecoder(
            convolution_dimensions=convolution_dimensions,
            in_channels=latent_channels,
            out_channels=out_channels,
            decoder_blocks=decoder_blocks,
            patch_size=patch_size,
            norm_layer=NormLayerType(norm_layer_str),
            causal=causal,
            timestep_conditioning=timestep_conditioning,
            decoder_spatial_padding_mode=spatial_padding_mode,
            base_channels=base_channels,
        )


VAE_DECODER_COMFY_KEYS_FILTER = (
    SDOps("VAE_DECODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="vae.decoder.")
    .with_matching(prefix="vae.per_channel_statistics.")
    .with_replacement("vae.decoder.", "")
    .with_replacement("vae.per_channel_statistics.", "per_channel_statistics.")
)

VAE_ENCODER_COMFY_KEYS_FILTER = (
    SDOps("VAE_ENCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="vae.encoder.")
    .with_matching(prefix="vae.per_channel_statistics.")
    .with_replacement("vae.encoder.", "")
    .with_replacement("vae.per_channel_statistics.", "per_channel_statistics.")
)
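A sketch of how the decoder configurator might be driven; the config dict shape follows the keys read above, but the block list here is illustrative rather than a shipped configuration.

from ltx_core.model.video_vae.model_configurator import VideoDecoderConfigurator

config = {
    "vae": {
        "dims": 3,
        "latent_channels": 128,
        "out_channels": 3,
        # illustrative blocks: one residual stage, one 2x upsample in all dims
        "decoder_blocks": [["res_x", {"num_layers": 2}], ["compress_all", {"multiplier": 2}]],
        "norm_layer": "pixel_norm",
        "causal_decoder": False,
        "timestep_conditioning": True,
    }
}
decoder = VideoDecoderConfigurator.from_config(config)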
packages/ltx-core/src/ltx_core/model/video_vae/resnet.py ADDED
@@ -0,0 +1,277 @@
from typing import Optional, Tuple, Union

import torch
from torch import nn

from ltx_core.model.common.normalization import PixelNorm
from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
from ltx_core.model.video_vae.convolution import make_conv_nd, make_linear_nd
from ltx_core.model.video_vae.enums import NormLayerType, PaddingModeType


class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        if norm_layer == NormLayerType.GROUP_NORM:
            self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.norm1 = PixelNorm()

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        if norm_layer == NormLayerType.GROUP_NORM:
            self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.norm2 = PixelNorm()

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        self.conv_shortcut = (
            make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
            if in_channels != out_channels
            else nn.Identity()
        )

        # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout,
        # avoiding the dimension rearrangement required by standard nn.LayerNorm
        self.norm3 = (
            nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels))

    def _feed_spatial_noise(
        self,
        hidden_states: torch.Tensor,
        per_channel_scale: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # similar to the "explicit noise inputs" method in StyleGAN
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise

        return hidden_states

    def forward(
        self,
        input_tensor: torch.Tensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]

        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            if timestep is None:
                raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
            ada_values = self.scale_shift_table[None, ..., None, None, None].to(
                device=hidden_states.device, dtype=hidden_states.dtype
            ) + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)

            hidden_states = hidden_states * (1 + scale1) + shift1

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states,
                self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype),
                generator=generator,
            )

        hidden_states = self.norm2(hidden_states)

        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states,
                self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype),
                generator=generator,
            )

        input_tensor = self.norm3(input_tensor)

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.

    Returns:
        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: NormLayerType = NormLayerType.GROUP_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                embedding_dim=in_channels * 4, size_emb_dim=0
            )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        timestep_embed = None
        if self.timestep_conditioning:
            if timestep is None:
                raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
            batch_size = hidden_states.shape[0]
            timestep_embed = self.time_embedder(
                timestep=timestep.flatten(),
                hidden_dtype=hidden_states.dtype,
            )
            timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1)

        for resnet in self.res_blocks:
            hidden_states = resnet(
                hidden_states,
                causal=causal,
                timestep=timestep_embed,
                generator=generator,
            )

        return hidden_states
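A minimal sketch of running the residual block, assuming `ltx-core` is importable; with the default `pixel_norm` and no timestep conditioning, no `timestep` argument is needed and the causal convolutions preserve every dimension.

import torch

from ltx_core.model.video_vae.resnet import ResnetBlock3D

block = ResnetBlock3D(dims=3, in_channels=64)  # identity shortcut since in == out channels
x = torch.randn(1, 64, 9, 32, 32)  # (B, C, F, H, W)
out = block(x, causal=True)
print(out.shape)  # torch.Size([1, 64, 9, 32, 32])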
packages/ltx-core/src/ltx_core/model/video_vae/tiling.py ADDED
@@ -0,0 +1,291 @@
import itertools
from dataclasses import dataclass
from typing import Callable, List, NamedTuple, Tuple

import torch


def compute_trapezoidal_mask_1d(
    length: int,
    ramp_left: int,
    ramp_right: int,
    left_starts_from_0: bool = False,
) -> torch.Tensor:
    """
    Generate a 1D trapezoidal blending mask with linear ramps.

    Args:
        length: Output length of the mask.
        ramp_left: Fade-in length on the left.
        ramp_right: Fade-out length on the right.
        left_starts_from_0: Whether the ramp starts from 0 or the first non-zero value.
            Useful for temporal tiles where the first tile is causal.

    Returns:
        A 1D tensor of shape `(length,)` with values in [0, 1].
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    ramp_left = max(0, min(ramp_left, length))
    ramp_right = max(0, min(ramp_right, length))

    mask = torch.ones(length)

    if ramp_left > 0:
        interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2
        fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1]
        if not left_starts_from_0:
            fade_in = fade_in[1:]
        mask[:ramp_left] *= fade_in

    if ramp_right > 0:
        fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1]
        mask[-ramp_right:] *= fade_out

    return mask.clamp_(0, 1)


def compute_rectangular_mask_1d(
    length: int,
    left_ramp: int,
    right_ramp: int,
) -> torch.Tensor:
    """
    Generate a 1D rectangular (pulse) mask.

    Args:
        length: Output length of the mask.
        left_ramp: Number of elements at the start of the mask to set to 0.
        right_ramp: Number of elements at the end of the mask to set to 0.

    Returns:
        A 1D tensor of shape `(length,)` with values 0 or 1.
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    mask = torch.ones(length)
    if left_ramp > 0:
        mask[:left_ramp] = 0
    if right_ramp > 0:
        mask[-right_ramp:] = 0
    return mask


@dataclass(frozen=True)
class SpatialTilingConfig:
    """Configuration for dividing each frame into spatial tiles with optional overlap.

    Args:
        tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32.
        tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32.
            Defaults to 0.
    """

    tile_size_in_pixels: int
    tile_overlap_in_pixels: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_pixels < 64:
            raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
        if self.tile_size_in_pixels % 32 != 0:
            raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
        if self.tile_overlap_in_pixels % 32 != 0:
            raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
        if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
            )


@dataclass(frozen=True)
class TemporalTilingConfig:
    """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap.

    Args:
        tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8.
        tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles.
            Must be divisible by 8. Defaults to 0.
    """

    tile_size_in_frames: int
    tile_overlap_in_frames: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_frames < 16:
            raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
        if self.tile_size_in_frames % 8 != 0:
            raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
        if self.tile_overlap_in_frames % 8 != 0:
            raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
        if self.tile_overlap_in_frames >= self.tile_size_in_frames:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
            )


@dataclass(frozen=True)
class TilingConfig:
    """Configuration for splitting video into tiles with optional overlap.

    Attributes:
        spatial_config: Configuration for splitting spatial dimensions into tiles.
        temporal_config: Configuration for splitting the temporal dimension into tiles.
    """

    spatial_config: SpatialTilingConfig | None = None
    temporal_config: TemporalTilingConfig | None = None

    @classmethod
    def default(cls) -> "TilingConfig":
        return cls(
            spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
            temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
        )


@dataclass(frozen=True)
class DimensionIntervals:
    """Defines how a single dimension is split into overlapping intervals (tiles).

    Each list has length N where N is the number of intervals. The i-th element
    of each list describes the i-th interval.

    Attributes:
        starts: Start index of each interval (inclusive).
        ends: End index of each interval (exclusive).
        left_ramps: Length of the left blend ramp for each interval.
            Used to create masks that fade in from 0 to 1.
        right_ramps: Length of the right blend ramp for each interval.
            Used to create masks that fade out from 1 to 0.
    """

    starts: List[int]
    ends: List[int]
    left_ramps: List[int]
    right_ramps: List[int]


@dataclass(frozen=True)
class TensorTilingSpec:
    """Specifies how a tensor of a given shape is split into intervals (tiles) along each dimension.

    Attributes:
        original_shape: Shape of the tensor being tiled.
        dimension_intervals: Per-dimension intervals (starts, ends, ramps) for each axis.
    """

    original_shape: torch.Size
    dimension_intervals: Tuple[DimensionIntervals, ...]


# Operation to split a single dimension of the tensor into intervals based on the length along the dimension.
SplitOperation = Callable[[int], DimensionIntervals]
# Operation to map the intervals in an input dimension to slices and masks along a corresponding output dimension.
MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]]


def default_split_operation(length: int) -> DimensionIntervals:
    return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0])


DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation


def default_mapping_operation(
    _intervals: DimensionIntervals,
) -> tuple[list[slice], list[torch.Tensor | None]]:
    return [slice(0, None)], [None]


DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation


class Tile(NamedTuple):
    """
    Represents a single tile.

    Attributes:
        in_coords:
            Tuple of slices specifying where to cut the tile from the INPUT tensor.
        out_coords:
            Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor.
        masks_1d:
            Per-dimension masks in OUTPUT units (None for dimensions that are not blended).
            These are used to create the all-dimensional blending mask.

    Methods:
        blend_mask:
            Create a single N-D mask from the per-dimension masks.
    """

    in_coords: Tuple[slice, ...]
    out_coords: Tuple[slice, ...]
    masks_1d: Tuple[torch.Tensor | None, ...]

    @property
    def blend_mask(self) -> torch.Tensor:
        num_dims = len(self.out_coords)
        per_dimension_masks: List[torch.Tensor] = []

        for dim_idx in range(num_dims):
            mask_1d = self.masks_1d[dim_idx]
            view_shape = [1] * num_dims
            if mask_1d is None:
                # Broadcast mask along this dimension (length 1).
                one = torch.ones(1)
                view_shape[dim_idx] = 1
                per_dimension_masks.append(one.view(*view_shape))
                continue

            # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply.
            view_shape[dim_idx] = mask_1d.shape[0]
            per_dimension_masks.append(mask_1d.view(*view_shape))

        # Multiply per-dimension masks to form the full N-D mask (separable blending window).
        combined_mask = per_dimension_masks[0]
        for mask in per_dimension_masks[1:]:
            combined_mask = combined_mask * mask

        return combined_mask


def create_tiles_from_intervals_and_mappers(
    intervals: TensorTilingSpec,
    mappers: List[MappingOperation],
) -> List[Tile]:
    full_dim_input_slices = []
    full_dim_output_slices = []
    full_dim_masks_1d = []
    for axis_index in range(len(intervals.original_shape)):
        dimension_intervals = intervals.dimension_intervals[axis_index]
        starts = dimension_intervals.starts
        ends = dimension_intervals.ends
        input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)]
        output_slices, masks_1d = mappers[axis_index](dimension_intervals)
        full_dim_input_slices.append(input_slices)
        full_dim_output_slices.append(output_slices)
        full_dim_masks_1d.append(masks_1d)

    tiles = []
    tile_in_coords = list(itertools.product(*full_dim_input_slices))
    tile_out_coords = list(itertools.product(*full_dim_output_slices))
    tile_mask_1ds = list(itertools.product(*full_dim_masks_1d))
    for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True):
        tiles.append(
            Tile(
                in_coords=in_coord,
                out_coords=out_coord,
                masks_1d=mask_1d,
            )
        )
    return tiles


def create_tiles(
    tensor_shape: torch.Size,
    splitters: List[SplitOperation],
    mappers: List[MappingOperation],
) -> List[Tile]:
    if len(splitters) != len(tensor_shape):
        raise ValueError(
            f"Number of splitters must be equal to number of dimensions in tensor shape, "
            f"got {len(splitters)} and {len(tensor_shape)}"
        )
    if len(mappers) != len(tensor_shape):
        raise ValueError(
            f"Number of mappers must be equal to number of dimensions in tensor shape, "
            f"got {len(mappers)} and {len(tensor_shape)}"
        )
    intervals = [splitter(length) for splitter, length in zip(splitters, tensor_shape, strict=True)]
    tiling_spec = TensorTilingSpec(original_shape=tensor_shape, dimension_intervals=tuple(intervals))
    return create_tiles_from_intervals_and_mappers(tiling_spec, mappers)
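For intuition, a quick check of the trapezoidal mask, with output values worked out from the ramp construction above: because the fade-in excludes 0 and the fade-out excludes 1, equal-length ramps of adjacent tiles sum to 1 across the overlap.

import torch

from ltx_core.model.video_vae.tiling import compute_trapezoidal_mask_1d

mask = compute_trapezoidal_mask_1d(length=8, ramp_left=3, ramp_right=2)
print(mask)
# tensor([0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 0.6667, 0.3333])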
packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py ADDED
@@ -0,0 +1,1219 @@
import logging
from dataclasses import replace
from typing import Any, Callable, Iterator, List, Tuple

import torch
from einops import rearrange
from torch import nn

from ltx_core.model.common.normalization import PixelNorm
from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
from ltx_core.model.video_vae.convolution import make_conv_nd
from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType
from ltx_core.model.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from ltx_core.model.video_vae.resnet import ResnetBlock3D, UNetMidBlock3D
from ltx_core.model.video_vae.sampling import DepthToSpaceUpsample, SpaceToDepthDownsample
from ltx_core.model.video_vae.tiling import (
    DEFAULT_MAPPING_OPERATION,
    DEFAULT_SPLIT_OPERATION,
    DimensionIntervals,
    MappingOperation,
    SplitOperation,
    Tile,
    TilingConfig,
    compute_rectangular_mask_1d,
    compute_trapezoidal_mask_1d,
    create_tiles,
)
from ltx_core.types import VIDEO_SCALE_FACTORS, SpatioTemporalScaleFactors, VideoLatentShape

logger: logging.Logger = logging.getLogger(__name__)


def _make_encoder_block(
    block_name: str,
    block_config: dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    out_channels = in_channels

    if block_name == "res_x":
        block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "res_x_y":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_time":
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 1, 1),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_space":
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(1, 2, 2),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all":
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 2, 2),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all_x_y":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 2, 2),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all_res":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=(2, 2, 2),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_space_res":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=(1, 2, 2),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_time_res":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=(2, 1, 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unknown block: {block_name}")

    return block, out_channels


class VideoEncoder(nn.Module):
    """
    Variational Autoencoder Encoder. Encodes video frames into a latent representation.

    The encoder compresses the input video through a series of downsampling operations controlled by
    patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W').

    Compression Behavior:
        The total compression is determined by:
        1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4)
        2. Sequential compression through encoder_blocks based on their stride patterns
        Compression blocks apply 2x compression in specified dimensions:
        - "compress_time" / "compress_time_res": temporal only
        - "compress_space" / "compress_space_res": spatial only (H and W)
        - "compress_all" / "compress_all_res": all dimensions (F, H, W)
        - "res_x" / "res_x_y": no compression

    Standard LTX Video configuration:
        - patch_size=4
        - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res
        - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32
        - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16)
        - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...)

    Args:
        convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).
        in_channels: The number of input channels. For RGB images, this is 3.
        out_channels: The number of output channels (latent channels). For latent channels, this is 128.
        encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params)
            where params is either an int (num_layers) or a dict with configuration.
        patch_size: The patch size for initial spatial compression. Should be a power of 2.
        norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`.
    """

    _DEFAULT_NORM_NUM_GROUPS = 32

    def __init__(
        self,
        convolution_dimensions: int = 3,
        in_channels: int = 3,
        out_channels: int = 128,
        encoder_blocks: List[Tuple[str, int]] | List[Tuple[str, dict[str, Any]]] = [],  # noqa: B006
        patch_size: int = 4,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
        encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        # Per-channel statistics for normalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)

        in_channels = in_channels * patch_size**2
        feature_channels = out_channels

        self.conv_in = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in encoder_blocks:
            # Convert int to dict format for uniform handling
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params

            block, feature_channels = _make_encoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=convolution_dimensions,
                norm_layer=norm_layer,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=encoder_spatial_padding_mode,
            )

            self.down_blocks.append(block)

        # out
        if norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == LogVarianceType.PER_CHANNEL:
            conv_out_channels *= 2
        elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
            conv_out_channels += 1
        elif latent_log_var != LogVarianceType.NONE:
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")

        self.conv_out = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=feature_channels,
            out_channels=conv_out_channels,
            kernel_size=3,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        r"""
        Encode video frames into a normalized latent representation.

        Args:
            sample: Input video (B, C, F, H, W). F should be 1 + 8*k (e.g., 1, 9, 17, 25, 33...).
                If not, the encoder crops the last frames to the nearest valid length.

        Returns:
            Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32.
            Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16).
        """
        # Validate frame count (crop to nearest valid length if needed)
        frames_count = sample.shape[2]
        if ((frames_count - 1) % 8) != 0:
            frames_to_crop = (frames_count - 1) % 8
            logger.warning(
                "Invalid number of frames %s for encode; cropping last %s frames to satisfy 1 + 8*k.",
                frames_count,
                frames_to_crop,
            )
            sample = sample[:, :, :-frames_to_crop, ...]

        # Initial spatial compression: trade spatial resolution for channel depth.
        # This reduces H,W by patch_size and increases channels, making convolutions more efficient.
        # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        for down_block in self.down_blocks:
            sample = down_block(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == LogVarianceType.UNIFORM:
            # Uniform variance: the model outputs N means and 1 shared log-variance channel.
            # Expand the single logvar to match the number of mean channels so the output
            # is format-compatible with PER_CHANNEL (means + logvar, each with N channels).
            # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129)
            # Target shape: (B, 2*N, ...) where the first N are means, the last N are logvar

            if sample.shape[1] < 2:
                raise ValueError(
                    f"Invalid channel count for UNIFORM mode: expected at least 2 channels "
                    f"(N means + 1 logvar), got {sample.shape[1]}"
                )

            # Extract means (first N channels) and logvar (last 1 channel)
            means = sample[:, :-1, ...]  # (B, N, ...)
            logvar = sample[:, -1:, ...]  # (B, 1, ...)

            # Repeat logvar N times to match means channels.
            # Use an expand/repeat pattern that works for both 4D and 5D tensors.
            num_channels = means.shape[1]
            repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2)
            repeated_logvar = logvar.repeat(*repeat_shape)  # (B, N, ...)

            # Concatenate to create the (B, 2*N, ...) format: [means, repeated_logvar]
            sample = torch.cat([means, repeated_logvar], dim=1)
        elif self.latent_log_var == LogVarianceType.CONSTANT:
            sample = sample[:, :-1, ...]
            approx_ln_0 = -30  # this is the minimal clamp value in DiagonalGaussianDistribution objects
            sample = torch.cat(
                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
                dim=1,
            )

        # Split into means and logvar, then normalize means
        means, _ = torch.chunk(sample, 2, dim=1)
        return self.per_channel_statistics.normalize(means)

    def tiled_encode(
        self,
        video: torch.Tensor,
        tiling_config: TilingConfig | None = None,
    ) -> torch.Tensor:
        """Encode video to latent using tiled processing of the given video tensor.

        Device Handling:
            - Input video can be on CPU or GPU
            - Accumulation buffers are created on the model's device
            - Each tile is automatically moved to the model's device before encoding
            - Output latent is returned on the model's device

        Args:
            video: Input video tensor (B, 3, F, H, W) in range [-1, 1]
            tiling_config: Tiling configuration for the video tensor

        Returns:
            Latent tensor (B, 128, F', H', W') on the model's device,
            where F' = 1 + (F-1)/8, H' = H/32, W' = W/32
        """
        # Detect model device and dtype
        model_device = next(self.parameters()).device
        model_dtype = next(self.parameters()).dtype

        # Extract shape components
        batch, _, frames, height, width = video.shape

        # Check frame count and crop if needed
        if (frames - 1) % VIDEO_SCALE_FACTORS.time != 0:
            frames_to_crop = (frames - 1) % VIDEO_SCALE_FACTORS.time
            logger.warning(
                f"Number of frames {frames} of input video is not ({VIDEO_SCALE_FACTORS.time} * k + 1), "
                f"last {frames_to_crop} frames will be cropped"
            )
            video = video[:, :, :-frames_to_crop, ...]
            # Update frames after cropping
            frames = video.shape[2]

        # Calculate output latent shape (inverse of upscale)
        latent_shape = VideoLatentShape(
            batch=batch,
            channels=self.latent_channels,  # 128 for standard VAE
            frames=(frames - 1) // VIDEO_SCALE_FACTORS.time + 1,
            height=height // VIDEO_SCALE_FACTORS.height,
            width=width // VIDEO_SCALE_FACTORS.width,
        )

        # Prepare tiles (operates on VIDEO dimensions)
        tiles = prepare_tiles_for_encoding(video, tiling_config)

        # Initialize accumulation buffers on the model device
        latent_buffer = torch.zeros(
            latent_shape.to_torch_shape(),
            device=model_device,
            dtype=model_dtype,
        )
        weights_buffer = torch.zeros_like(latent_buffer)

        # Process each tile
        for tile in tiles:
            # Extract video tile from input (may be on CPU)
            video_tile = video[tile.in_coords]

            # Move tile to model device if needed
            if video_tile.device != model_device or video_tile.dtype != model_dtype:
                video_tile = video_tile.to(device=model_device, dtype=model_dtype)

            # Encode tile to latent (output on model device)
            latent_tile = self.forward(video_tile)

            # Move blend mask to model device
            mask = tile.blend_mask.to(
                device=model_device,
                dtype=model_dtype,
            )

            # Weighted accumulation in latent space
            latent_buffer[tile.out_coords] += latent_tile * mask
            weights_buffer[tile.out_coords] += mask

            del latent_tile, mask, video_tile

        # Normalize by accumulated weights
        weights_buffer = weights_buffer.clamp(min=1e-8)
        return latent_buffer / weights_buffer

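A sketch of tiled encoding, assuming an already constructed `encoder: VideoEncoder` (block configuration and weights come from the configurators and checkpoint loaders elsewhere in the package); with the default tiling the 512x512 input is processed as overlapping 512-pixel tiles and blended in latent space.

import torch

from ltx_core.model.video_vae.tiling import TilingConfig

video = torch.randn(1, 3, 33, 512, 512)  # (B, C, F, H, W), F = 1 + 8*k
latent = encoder.tiled_encode(video, tiling_config=TilingConfig.default())
print(latent.shape)  # expected: torch.Size([1, 128, 5, 16, 16])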
407
+ def prepare_tiles_for_encoding(
408
+ video: torch.Tensor,
409
+ tiling_config: TilingConfig | None = None,
410
+ ) -> List[Tile]:
411
+ """Prepare tiles for VAE encoding.
412
+ Args:
413
+ video: Input video tensor (B, 3, F, H, W) in range [-1, 1]
414
+ tiling_config: Tiling configuration for the video tensor
415
+ Returns:
416
+ List of tiles for the video tensor
417
+ """
418
+
419
+ splitters = [DEFAULT_SPLIT_OPERATION] * len(video.shape)
420
+ mappers = [DEFAULT_MAPPING_OPERATION] * len(video.shape)
421
+ minimum_spatial_overlap_px = 64
422
+ minimum_temporal_overlap_frames = 16
423
+
424
+ if tiling_config is not None and tiling_config.spatial_config is not None:
425
+ cfg = tiling_config.spatial_config
426
+
427
+ tile_size_px = cfg.tile_size_in_pixels
428
+ overlap_px = cfg.tile_overlap_in_pixels
429
+
430
+ # Set minimum spatial overlap to 64 pixels in order to allow cutting padding from
431
+ # the front and back of the tiles and concatenate tiles without artifacts.
432
+ # The encoder uses symmetric padding (pad=1) in H and W at each conv layer. At tile
433
+ # boundaries, convs see padding (zeros/reflect) instead of real neighbor pixels, causing
434
+ # incorrect context near edges.
435
+ # For each overlap we discard 1 latent per edge (32px at scale 32) and concatenate tiles at a
436
+ # shared region with the next tile.
437
+ if overlap_px < minimum_spatial_overlap_px:
438
+ logger.warning(
439
+ f"Overlap pixels {overlap_px} in spatial tiling is less than \
440
+ {minimum_spatial_overlap_px}, setting to minimum required {minimum_spatial_overlap_px}"
441
+ )
442
+ overlap_px = minimum_spatial_overlap_px
443
+
444
+ # Define split and map operations for the spatial dimensions
445
+
446
+ # Height axis (H)
447
+ splitters[3] = split_with_symmetric_overlaps(tile_size_px, overlap_px)
448
+ mappers[3] = make_mapping_operation(map_spatial_interval_to_latent, scale=VIDEO_SCALE_FACTORS.height)
449
+
450
+ # Width axis (W)
451
+ splitters[4] = split_with_symmetric_overlaps(tile_size_px, overlap_px)
452
+ mappers[4] = make_mapping_operation(map_spatial_interval_to_latent, scale=VIDEO_SCALE_FACTORS.width)
453
+
454
+ if tiling_config is not None and tiling_config.temporal_config is not None:
455
+ cfg = tiling_config.temporal_config
456
+ tile_size_frames = cfg.tile_size_in_frames
457
+ overlap_frames = cfg.tile_overlap_in_frames
458
+
459
+ if overlap_frames < minimum_temporal_overlap_frames:
460
+ logger.warning(f"Overlap frames {overlap_frames} is less than 16, setting to minimum required 16")
461
+ overlap_frames = minimum_temporal_overlap_frames
462
+
463
+ splitters[2] = split_temporal_frames(tile_size_frames, overlap_frames)
464
+ mappers[2] = make_mapping_operation(map_temporal_interval_to_latent, scale=VIDEO_SCALE_FACTORS.time)
465
+
466
+ return create_tiles(video.shape, splitters, mappers)
467
+
468
+
+ def _make_decoder_block(
+ block_name: str,
+ block_config: dict[str, Any],
+ in_channels: int,
+ convolution_dimensions: int,
+ norm_layer: NormLayerType,
+ timestep_conditioning: bool,
+ norm_num_groups: int,
+ spatial_padding_mode: PaddingModeType,
+ ) -> Tuple[nn.Module, int]:
+ out_channels = in_channels
+ if block_name == "res_x":
+ block = UNetMidBlock3D(
+ dims=convolution_dimensions,
+ in_channels=in_channels,
+ num_layers=block_config["num_layers"],
+ resnet_eps=1e-6,
+ resnet_groups=norm_num_groups,
+ norm_layer=norm_layer,
+ inject_noise=block_config.get("inject_noise", False),
+ timestep_conditioning=timestep_conditioning,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "attn_res_x":
+ block = UNetMidBlock3D(
+ dims=convolution_dimensions,
+ in_channels=in_channels,
+ num_layers=block_config["num_layers"],
+ resnet_groups=norm_num_groups,
+ norm_layer=norm_layer,
+ inject_noise=block_config.get("inject_noise", False),
+ timestep_conditioning=timestep_conditioning,
+ attention_head_dim=block_config["attention_head_dim"],
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "res_x_y":
+ out_channels = in_channels // block_config.get("multiplier", 2)
+ block = ResnetBlock3D(
+ dims=convolution_dimensions,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ eps=1e-6,
+ groups=norm_num_groups,
+ norm_layer=norm_layer,
+ inject_noise=block_config.get("inject_noise", False),
+ timestep_conditioning=False,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "compress_time":
+ out_channels = in_channels // block_config.get("multiplier", 1)
+ block = DepthToSpaceUpsample(
+ dims=convolution_dimensions,
+ in_channels=in_channels,
+ stride=(2, 1, 1),
+ out_channels_reduction_factor=block_config.get("multiplier", 1),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "compress_space":
+ out_channels = in_channels // block_config.get("multiplier", 1)
+ block = DepthToSpaceUpsample(
+ dims=convolution_dimensions,
+ in_channels=in_channels,
+ stride=(1, 2, 2),
+ out_channels_reduction_factor=block_config.get("multiplier", 1),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "compress_all":
+ out_channels = in_channels // block_config.get("multiplier", 1)
+ block = DepthToSpaceUpsample(
+ dims=convolution_dimensions,
+ in_channels=in_channels,
+ stride=(2, 2, 2),
+ residual=block_config.get("residual", False),
+ out_channels_reduction_factor=block_config.get("multiplier", 1),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ else:
+ raise ValueError(f"Unknown decoder block: {block_name}")
+
+ return block, out_channels
+
+
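+ # Channel bookkeeping example (values follow the block rules above): starting
+ # from 1024 feature channels (the default base_channels * 8), ("res_x_y",
+ # {"multiplier": 2}) halves to 512 and any "compress_*" block with
+ # {"multiplier": 2} halves again, while "res_x" and "attn_res_x" keep the
+ # channel count unchanged.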
+ class VideoDecoder(nn.Module):
+ """
+ Variational Autoencoder Decoder. Decodes a latent representation into video frames.
+ The decoder upsamples latents through a series of upsampling operations (the inverse of the encoder).
+ Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for the standard LTX Video configuration.
+ Upsampling blocks expand dimensions by 2x in the specified dimensions:
+ - "compress_time": temporal only
+ - "compress_space": spatial only (H and W)
+ - "compress_all": all dimensions (F, H, W)
+ - "res_x" / "res_x_y" / "attn_res_x": no upsampling
+ Causal Mode:
+ causal=False (standard): Symmetric padding, allows future frame dependencies.
+ causal=True: Causal padding, each frame depends only on past/current frames.
+ The first frame is removed after temporal upsampling in both modes; the output shape is unchanged.
+ Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes.
+ Args:
+ convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).
+ in_channels: The number of input channels (latent channels). Default is 128.
+ out_channels: The number of output channels. For RGB images, this is 3.
+ decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params)
+ where params is either an int (num_layers) or a dict with configuration.
+ patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion:
+ H -> Hx4, W -> Wx4. Should be a power of 2.
+ norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
+ causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding.
+ When True, uses causal padding (past/current frames only).
+ timestep_conditioning: Whether to condition the decoder on the timestep for denoising.
+ """
+
+ _DEFAULT_NORM_NUM_GROUPS = 32
+
+ def __init__(
+ self,
+ convolution_dimensions: int = 3,
+ in_channels: int = 128,
+ out_channels: int = 3,
+ decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006
+ patch_size: int = 4,
+ norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
+ causal: bool = False,
+ timestep_conditioning: bool = False,
+ decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
+ base_channels: int = 128,
+ ):
+ super().__init__()
+
+ # Spatiotemporal downscaling between decoded video space and VAE latents.
+ # According to the LTXV paper, the standard configuration downsamples
+ # video inputs by a factor of 8 in the temporal dimension and 32 in
+ # each spatial dimension (height and width). This parameter determines how
+ # many video frames and pixels correspond to a single latent cell.
+ self.video_downscale_factors = SpatioTemporalScaleFactors(
+ time=8,
+ width=32,
+ height=32,
+ )
+
+ self.patch_size = patch_size
+ out_channels = out_channels * patch_size**2
+ self.causal = causal
+ self.timestep_conditioning = timestep_conditioning
+ self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
+
+ # Per-channel statistics for denormalizing latents
+ self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
+
+ # Noise and timestep parameters for decoder conditioning
+ self.decode_noise_scale = 0.025
+ self.decode_timestep = 0.05
+
+ # The LTX VAE decoder architecture uses 3 upsampler blocks, each with a multiplier of 2.
+ # Hence the total feature_channels is base_channels multiplied by 8 (2^3).
+ feature_channels = base_channels * 8
+
+ self.conv_in = make_conv_nd(
+ dims=convolution_dimensions,
+ in_channels=in_channels,
+ out_channels=feature_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ causal=True,
+ spatial_padding_mode=decoder_spatial_padding_mode,
+ )
+
+ self.up_blocks = nn.ModuleList([])
+
+ for block_name, block_params in reversed(decoder_blocks):
+ # Convert int to dict format for uniform handling
+ block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
+
+ block, feature_channels = _make_decoder_block(
+ block_name=block_name,
+ block_config=block_config,
+ in_channels=feature_channels,
+ convolution_dimensions=convolution_dimensions,
+ norm_layer=norm_layer,
+ timestep_conditioning=timestep_conditioning,
+ norm_num_groups=self._norm_num_groups,
+ spatial_padding_mode=decoder_spatial_padding_mode,
+ )
+
+ self.up_blocks.append(block)
+
+ if norm_layer == NormLayerType.GROUP_NORM:
+ self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)
+ elif norm_layer == NormLayerType.PIXEL_NORM:
+ self.conv_norm_out = PixelNorm()
+
+ self.conv_act = nn.SiLU()
+ self.conv_out = make_conv_nd(
+ dims=convolution_dimensions,
+ in_channels=feature_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ causal=True,
+ spatial_padding_mode=decoder_spatial_padding_mode,
+ )
+
+ if timestep_conditioning:
+ self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0))
+ self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
+ embedding_dim=feature_channels * 2, size_emb_dim=0
+ )
+ self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels))
+
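+ # Hypothetical construction sketch (block names come from _make_decoder_block;
+ # the production block list ships with the model config, so treat the values
+ # below as illustrative only):
+ # decoder = VideoDecoder(
+ # decoder_blocks=[("res_x", 2), ("compress_all", {"multiplier": 2}), ("res_x", 2)],
+ # timestep_conditioning=True,
+ # )
+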
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: torch.Tensor | None = None,
+ generator: torch.Generator | None = None,
+ ) -> torch.Tensor:
+ r"""
+ Decode latent representation into video frames.
+ Args:
+ sample: Latent tensor (B, 128, F', H', W').
+ timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None.
+ generator: Random generator for deterministic noise injection (if inject_noise=True in blocks).
+ Returns:
+ Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'.
+ Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512).
+ Note: First frame is removed after temporal upsampling regardless of causal mode.
+ When causal=False, allows future frame dependencies in convolutions but maintains same output shape.
+ """
+ batch_size = sample.shape[0]
+
+ # Add noise if timestep conditioning is enabled
+ if self.timestep_conditioning:
+ noise = (
+ torch.randn(
+ sample.size(),
+ generator=generator,
+ dtype=sample.dtype,
+ device=sample.device,
+ )
+ * self.decode_noise_scale
+ )
+
+ sample = noise + (1.0 - self.decode_noise_scale) * sample
+
+ # Denormalize latents
+ sample = self.per_channel_statistics.un_normalize(sample)
+
+ # Use default decode_timestep if timestep not provided
+ if timestep is None and self.timestep_conditioning:
+ timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype)
+
+ sample = self.conv_in(sample, causal=self.causal)
+
+ scaled_timestep = None
+ if self.timestep_conditioning:
+ if timestep is None:
+ raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
+ scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample)
+
+ for up_block in self.up_blocks:
+ if isinstance(up_block, UNetMidBlock3D):
+ block_kwargs = {
+ "causal": self.causal,
+ "timestep": scaled_timestep if self.timestep_conditioning else None,
+ "generator": generator,
+ }
+ sample = up_block(sample, **block_kwargs)
+ elif isinstance(up_block, ResnetBlock3D):
+ sample = up_block(sample, causal=self.causal, generator=generator)
+ else:
+ sample = up_block(sample, causal=self.causal)
+
+ sample = self.conv_norm_out(sample)
+
+ if self.timestep_conditioning:
+ embedded_timestep = self.last_time_embedder(
+ timestep=scaled_timestep.flatten(),
+ hidden_dtype=sample.dtype,
+ )
+ embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1)
+ ada_values = self.last_scale_shift_table[None, ..., None, None, None].to(
+ device=sample.device, dtype=sample.dtype
+ ) + embedded_timestep.reshape(
+ batch_size,
+ 2,
+ -1,
+ embedded_timestep.shape[-3],
+ embedded_timestep.shape[-2],
+ embedded_timestep.shape[-1],
+ )
+ shift, scale = ada_values.unbind(dim=1)
+ sample = sample * (1 + scale) + shift
+
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample, causal=self.causal)
+
+ # Final spatial expansion: reverse the initial patchify from encoder
+ # Moves pixels from channels back to spatial dimensions
+ # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4
+ sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
+
+ return sample
+
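+ # Usage sketch (shapes follow the class docstring; the decoder instance is
+ # assumed to be configured for the standard 8x/32x/32x factors):
+ # latent = torch.randn(1, 128, 5, 16, 16)
+ # video = decoder(latent) # -> (1, 3, 33, 512, 512), values roughly in [-1, 1]
+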
+ def _prepare_tiles(
+ self,
+ latent: torch.Tensor,
+ tiling_config: TilingConfig | None = None,
+ ) -> List[Tile]:
+ splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape)
+ mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape)
+ if tiling_config is not None and tiling_config.spatial_config is not None:
+ cfg = tiling_config.spatial_config
+ long_side = max(latent.shape[3], latent.shape[4])
+
+ def enable_on_axis(axis_idx: int, factor: int) -> None:
+ size = cfg.tile_size_in_pixels // factor
+ overlap = cfg.tile_overlap_in_pixels // factor
+ axis_length = latent.shape[axis_idx]
+ lower_threshold = max(2, overlap + 1)
+ tile_size = max(lower_threshold, round(size * axis_length / long_side))
+ splitters[axis_idx] = split_with_symmetric_overlaps(tile_size, overlap)
+ mappers[axis_idx] = make_mapping_operation(map_spatial_interval_to_pixel, scale=factor)
+
+ enable_on_axis(3, self.video_downscale_factors.height)
+ enable_on_axis(4, self.video_downscale_factors.width)
+
+ if tiling_config is not None and tiling_config.temporal_config is not None:
+ cfg = tiling_config.temporal_config
+ tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time
+ overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time
+ splitters[2] = split_temporal_latents(tile_size, overlap)
+ mappers[2] = make_mapping_operation(map_temporal_interval_to_frame, scale=self.video_downscale_factors.time)
+
+ return create_tiles(latent.shape, splitters, mappers)
+
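+ # Worked example for enable_on_axis (illustrative numbers): with
+ # tile_size_in_pixels=512, tile_overlap_in_pixels=64 and a height factor of 32,
+ # size=16 and overlap=2 in latent cells; for a square latent (axis_length ==
+ # long_side) the tile size stays max(3, 16) = 16, while a shorter axis is
+ # scaled down proportionally to keep tiles roughly square in pixel space.
+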
+ def tiled_decode(
+ self,
+ latent: torch.Tensor,
+ tiling_config: TilingConfig | None = None,
+ timestep: torch.Tensor | None = None,
+ generator: torch.Generator | None = None,
+ ) -> Iterator[torch.Tensor]:
+ """
+ Decode a latent tensor into video frames using tiled processing.
+ Splits the latent tensor into tiles, decodes each tile individually,
+ and yields video chunks as they become available.
+ Args:
+ latent: Input latent tensor (B, C, F', H', W').
+ tiling_config: Tiling configuration for the latent tensor.
+ timestep: Optional timestep for decoder conditioning.
+ generator: Optional random generator for deterministic decoding.
+ Yields:
+ Video chunks (B, C, T, H, W), ordered by temporal slice.
+ """
+
+ # Calculate full video shape from latent shape to get spatial dimensions
+ full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors)
+ tiles = self._prepare_tiles(latent, tiling_config)
+
+ temporal_groups = self._group_tiles_by_temporal_slice(tiles)
+
+ # State for temporal overlap handling
+ previous_chunk = None
+ previous_weights = None
+ previous_temporal_slice = None
+
+ for temporal_group_tiles in temporal_groups:
+ curr_temporal_slice = temporal_group_tiles[0].out_coords[2]
+
+ # Calculate the shape of the temporal buffer for this group of tiles.
+ # The temporal length depends on whether this is the first tile (starts at 0) or not.
+ # - First tile: (frames - 1) * scale + 1
+ # - Subsequent tiles: frames * scale
+ # This logic is handled by TemporalAxisMapping and reflected in out_coords.
+ temporal_tile_buffer_shape = full_video_shape._replace(
+ frames=curr_temporal_slice.stop - curr_temporal_slice.start,
+ )
+
+ buffer = torch.zeros(
+ temporal_tile_buffer_shape.to_torch_shape(),
+ device=latent.device,
+ dtype=latent.dtype,
+ )
+
+ curr_weights = self._accumulate_temporal_group_into_buffer(
+ group_tiles=temporal_group_tiles,
+ buffer=buffer,
+ latent=latent,
+ timestep=timestep,
+ generator=generator,
+ )
+
+ # Blend with previous temporal chunk if it exists
+ if previous_chunk is not None:
+ # Check if current temporal slice overlaps with previous temporal slice
+ if previous_temporal_slice.stop > curr_temporal_slice.start:
+ overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start
+ temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None)
+
+ # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer
+ # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add
+ # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the
+ # previous buffers, then later normalize by weights.
+ previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :]
+ previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[
+ :, :, slice(0, overlap_len), :, :
+ ]
+
+ buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :]
+ curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[
+ :, :, temporal_overlap_slice, :, :
+ ]
+
+ # Yield the non-overlapping part of the previous chunk
+ previous_weights = previous_weights.clamp(min=1e-8)
+ yield_len = curr_temporal_slice.start - previous_temporal_slice.start
+ yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :]
+
+ # Update state for next iteration
+ previous_chunk = buffer
+ previous_weights = curr_weights
+ previous_temporal_slice = curr_temporal_slice
+
+ # Yield any remaining chunk
+ if previous_chunk is not None:
+ previous_weights = previous_weights.clamp(min=1e-8)
+ yield previous_chunk / previous_weights
+
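+ # Usage sketch (write_chunk is a hypothetical caller-side callback; chunks
+ # stream in temporal order, already blended across tile seams):
+ # for chunk in decoder.tiled_decode(latent, tiling_config):
+ # write_chunk(chunk) # (B, C, T, H, W)
+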
+ def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]:
+ """Group tiles by their temporal output slice."""
+ if not tiles:
+ return []
+
+ groups = []
+ current_slice = tiles[0].out_coords[2]
+ current_group = []
+
+ for tile in tiles:
+ tile_slice = tile.out_coords[2]
+ if tile_slice == current_slice:
+ current_group.append(tile)
+ else:
+ groups.append(current_group)
+ current_slice = tile_slice
+ current_group = [tile]
+
+ # Add the final group
+ if current_group:
+ groups.append(current_group)
+
+ return groups
+
+ def _accumulate_temporal_group_into_buffer(
+ self,
+ group_tiles: List[Tile],
+ buffer: torch.Tensor,
+ latent: torch.Tensor,
+ timestep: torch.Tensor | None,
+ generator: torch.Generator | None,
+ ) -> torch.Tensor:
+ """
+ Decode and accumulate all tiles of a temporal group into a local buffer.
+ The buffer is local to the group and always starts at time 0; temporal coordinates
+ are rebased by subtracting temporal_slice.start.
+ """
+ temporal_slice = group_tiles[0].out_coords[2]
+
+ weights = torch.zeros_like(buffer)
+
+ for tile in group_tiles:
+ decoded_tile = self.forward(latent[tile.in_coords], timestep, generator)
+ mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype)
+ temporal_offset = tile.out_coords[2].start - temporal_slice.start
+ # Use the tile's output coordinate length, not the decoded tile's length,
+ # as the decoder may produce a different number of frames than expected
+ expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start
+ decoded_temporal_len = decoded_tile.shape[2]
+
+ # Ensure we don't exceed the buffer or decoded tile bounds
+ actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset)
+
+ chunk_coords = (
+ slice(None), # batch
+ slice(None), # channels
+ slice(temporal_offset, temporal_offset + actual_temporal_len),
+ tile.out_coords[3], # height
+ tile.out_coords[4], # width
+ )
+
+ # Slice decoded_tile and mask to match the actual length we're writing
+ decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :]
+ mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask
+
+ buffer[chunk_coords] += decoded_slice * mask_slice
+ weights[chunk_coords] += mask_slice
+
+ return weights
+
+
+ def decode_video(
+ latent: torch.Tensor,
+ video_decoder: VideoDecoder,
+ tiling_config: TilingConfig | None = None,
+ generator: torch.Generator | None = None,
+ ) -> Iterator[torch.Tensor]:
+ """
+ Decode a video latent tensor with the given decoder.
+ Args:
+ latent: Latent tensor (1, C, F', H', W'); only the first batch element is converted.
+ video_decoder: Decoder module.
+ tiling_config: Optional tiling settings.
+ generator: Optional random generator for deterministic decoding.
+ Yields:
+ Decoded chunk [f, h, w, c], uint8 in [0, 255].
+ """
+
+ def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor:
+ frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8)
+ frames = rearrange(frames[0], "c f h w -> f h w c")
+ return frames
+
+ if tiling_config is not None:
+ for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
+ yield convert_to_uint8(frames)
+ else:
+ decoded_video = video_decoder(latent, generator=generator)
+ yield convert_to_uint8(decoded_video)
+
+
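+ # Usage sketch (frame collection only; writing to disk is up to the caller):
+ # chunks = [chunk.cpu().numpy() for chunk in decode_video(latent, decoder, cfg)]
+ # # each chunk is (f, h, w, c) uint8, ready for a video writer
+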
+ def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int:
+ """
+ Get the number of video chunks for a given number of frames and tiling configuration.
+ Args:
+ num_frames: Number of frames in the video.
+ tiling_config: Tiling configuration.
+ Returns:
+ Number of video chunks.
+ """
+ if not tiling_config or not tiling_config.temporal_config:
+ return 1
+ cfg = tiling_config.temporal_config
+ frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames
+ return (num_frames - 1 + frame_stride - 1) // frame_stride
+
+
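+ # Worked example: num_frames=121, tile_size_in_frames=80, tile_overlap_in_frames=24
+ # gives frame_stride=56, so the function returns ceil(120 / 56) = 3.
+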
+ def split_with_symmetric_overlaps(size: int, overlap: int) -> SplitOperation:
+ def split(dimension_size: int) -> DimensionIntervals:
+ if dimension_size <= size:
+ return DEFAULT_SPLIT_OPERATION(dimension_size)
+ amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
+ starts = [i * (size - overlap) for i in range(amount)]
+ ends = [start + size for start in starts]
+ ends[-1] = dimension_size
+ left_ramps = [0] + [overlap] * (amount - 1)
+ right_ramps = [overlap] * (amount - 1) + [0]
+ return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
+
+ return split
+
+
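+ # Worked example: size=24, overlap=8, dimension_size=56 gives amount=3,
+ # starts=[0, 16, 32], ends=[24, 40, 56], left_ramps=[0, 8, 8], right_ramps=[8, 8, 0],
+ # i.e. tiles [0, 24), [16, 40), [32, 56) with 8-unit blend ramps at interior seams.
+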
+ def split_temporal_latents(size: int, overlap: int) -> SplitOperation:
+ """Split a temporal axis into overlapping tiles with causal handling.
+ Example with size=24, overlap=8 (units are whatever axis you split):
+ Non-causal split would produce:
+ Tile 0: [0, 24), left_ramp=0, right_ramp=8
+ Tile 1: [16, 40), left_ramp=8, right_ramp=8
+ Tile 2: [32, 56), left_ramp=8, right_ramp=0
+ Causal split produces:
+ Tile 0: [0, 24), left_ramp=0, right_ramp=8 (unchanged - starts at anchor)
+ Tile 1: [15, 40), left_ramp=9, right_ramp=8 (shifted back 1, ramp +1)
+ Tile 2: [31, 56), left_ramp=9, right_ramp=0 (shifted back 1, ramp +1)
+ This ensures each tile can causally depend on frames from previous tiles while maintaining
+ proper temporal continuity through the blend ramps.
+ Args:
+ size: Tile size in *axis units* (latent steps for LTX time tiling)
+ overlap: Overlap between tiles in the same units
+ Returns:
+ Split operation that divides temporal dimension with causal handling
+ """
+ non_causal_split = split_with_symmetric_overlaps(size, overlap)
+
+ def split(dimension_size: int) -> DimensionIntervals:
+ if dimension_size <= size:
+ return DEFAULT_SPLIT_OPERATION(dimension_size)
+ intervals = non_causal_split(dimension_size)
+
+ # Shift non-first tiles back by one step so each tile sees causal context from the previous tile
+ starts = list(intervals.starts)
+ starts[1:] = [s - 1 for s in starts[1:]]
+
+ # Extend blend ramps by 1 for non-first tiles to blend over the extra frame
+ left_ramps = list(intervals.left_ramps)
+ left_ramps[1:] = [r + 1 for r in left_ramps[1:]]
+
+ return replace(intervals, starts=starts, left_ramps=left_ramps)
+
+ return split
+
+
+ def split_temporal_frames(tile_size_frames: int, overlap_frames: int) -> SplitOperation:
+ """Split a temporal axis in video frame space into overlapping tiles.
+ Args:
+ tile_size_frames: Tile length in frames.
+ overlap_frames: Overlap between consecutive tiles in frames.
+ Returns:
+ Split operation that takes frame count and returns DimensionIntervals in frame indices.
+ """
+ non_causal_split = split_with_symmetric_overlaps(tile_size_frames, overlap_frames)
+
+ def split(dimension_size: int) -> DimensionIntervals:
+ if dimension_size <= tile_size_frames:
+ return DEFAULT_SPLIT_OPERATION(dimension_size)
+ intervals = non_causal_split(dimension_size)
+ # Extend all but the last tile by one frame and zero the right ramps
+ ends = list(intervals.ends)
+ ends[:-1] = [e + 1 for e in ends[:-1]]
+ right_ramps = [0] * len(intervals.right_ramps)
+ return replace(intervals, ends=ends, right_ramps=right_ramps)
+
+ return split
+
+
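+ # Worked example: tile_size_frames=80, overlap_frames=24, dimension_size=160.
+ # The symmetric split yields tiles [0, 80), [56, 136), [112, 160); extending all
+ # but the last end by one frame gives [0, 81), [56, 137), [112, 160) with
+ # left_ramps=[0, 24, 24] and all right_ramps zeroed.
+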
+ def make_mapping_operation(
+ map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor | None]],
+ scale: int,
+ ) -> MappingOperation:
+ """Create a mapping operation over a set of tiling intervals.
+ The given mapping function is applied to each interval in the input dimension. The resulting function is used
+ for creating tiles in the output dimension.
+ Args:
+ map_func: Mapping function to create the mapping operation from
+ scale: Scale factor for the transformation, used as an argument for the mapping function
+ Returns:
+ Mapping operation that takes a set of tiling intervals and returns a set of slices and masks in the output
+ dimension.
+ """
+
+ def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]:
+ output_slices: list[slice] = []
+ masks_1d: list[torch.Tensor | None] = []
+ number_of_slices = len(intervals.starts)
+ for i in range(number_of_slices):
+ start = intervals.starts[i]
+ end = intervals.ends[i]
+ left_ramp = intervals.left_ramps[i]
+ right_ramp = intervals.right_ramps[i]
+ output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale)
+ output_slices.append(output_slice)
+ masks_1d.append(mask_1d)
+ return output_slices, masks_1d
+
+ return map_op
+
+
+ def map_temporal_interval_to_frame(
+ begin: int,
+ end: int,
+ left_ramp: int,
+ right_ramp: int,
+ scale: int,
+ ) -> Tuple[slice, torch.Tensor]:
+ """Map temporal interval in latent space to video frame space.
+ Args:
+ begin: Start position in latent space
+ end: End position in latent space
+ left_ramp: Left ramp size in latent space
+ right_ramp: Right ramp size in latent space
+ scale: Scale factor for transformation
+ Returns:
+ Tuple of (output_slice, blend_mask)
+ """
+ start = begin * scale
+ stop = 1 + (end - 1) * scale
+
+ left_ramp_frames = 0 if left_ramp == 0 else 1 + (left_ramp - 1) * scale
+ right_ramp_frames = right_ramp * scale
+
+ mask_1d = compute_trapezoidal_mask_1d(stop - start, left_ramp_frames, right_ramp_frames, True)
+ return slice(start, stop), mask_1d
+
+
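+ # Worked example with scale=8: latent interval [0, 10) with right_ramp=3 maps to
+ # frames slice(0, 73) (i.e. 1 + 9*8), with a 24-frame right blend ramp; a left
+ # ramp of 3 latents would map to 1 + 2*8 = 17 frames.
+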
+ def map_temporal_interval_to_latent(
+ begin: int, end: int, left_ramp: int, right_ramp: int | None = None, scale: int = 1
+ ) -> Tuple[slice, torch.Tensor]:
+ """
+ Map temporal interval in video frame space to latent space.
+ Args:
+ begin: Start position in video frame space
+ end: End position in video frame space
+ left_ramp: Left ramp size in video frame space
+ right_ramp: Right ramp size in video frame space
+ scale: Scale factor for transformation
+ Returns:
+ Tuple of (output_slice, blend_mask)
+ """
+ start = begin // scale
+ stop = (end - 1) // scale + 1
+
+ left_ramp_latents = 0 if left_ramp == 0 else 1 + (left_ramp - 1) // scale
+ right_ramp_latents = 0 if right_ramp is None else right_ramp // scale
+
+ if right_ramp_latents != 0:
+ raise ValueError("For tiled encoding, temporal tiles are expected to have a right ramp equal to 0")
+
+ mask_1d = compute_rectangular_mask_1d(stop - start, left_ramp_latents, right_ramp_latents)
+
+ return slice(start, stop), mask_1d
+
+
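+ # Worked example with scale=8: frame interval [56, 137) with left_ramp=24 maps to
+ # latent slice(7, 18) with a left ramp of 1 + 23 // 8 = 3 latents (the right ramp
+ # must already be zero, as produced by split_temporal_frames).
+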
+ def map_spatial_interval_to_pixel(
+ begin: int,
+ end: int,
+ left_ramp: int,
+ right_ramp: int,
+ scale: int,
+ ) -> Tuple[slice, torch.Tensor]:
+ """Map spatial interval in latent space to pixel space.
+ Args:
+ begin: Start position in latent space
+ end: End position in latent space
+ left_ramp: Left ramp size in latent space
+ right_ramp: Right ramp size in latent space
+ scale: Scale factor for transformation
+ Returns:
+ Tuple of (output_slice, blend_mask)
+ """
+ start = begin * scale
+ stop = end * scale
+ mask_1d = compute_trapezoidal_mask_1d(stop - start, left_ramp * scale, right_ramp * scale, False)
+ return slice(start, stop), mask_1d
+
+
+ def map_spatial_interval_to_latent(
+ begin: int,
+ end: int,
+ left_ramp: int,
+ right_ramp: int,
+ scale: int,
+ ) -> Tuple[slice, torch.Tensor]:
+ """Map spatial interval in pixel space to latent space.
+ Args:
+ begin: Start position in pixel space
+ end: End position in pixel space
+ left_ramp: Left ramp size in pixel space
+ right_ramp: Right ramp size in pixel space
+ scale: Scale factor for transformation
+ Returns:
+ Tuple of (output_slice, blend_mask)
+ """
+ start = begin // scale
+ stop = end // scale
+ left_ramp = max(0, left_ramp // scale - 1)
+
+ right_ramp = 0 if right_ramp == 0 else 1
+
+ mask_1d = compute_rectangular_mask_1d(stop - start, left_ramp, right_ramp)
+ return slice(start, stop), mask_1d
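+ # Worked example with scale=32: pixel interval [448, 768) with left_ramp=64 maps
+ # to latent slice(14, 24) with left_ramp = max(0, 64 // 32 - 1) = 1 latent cell.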
packages/ltx-core/src/ltx_core/quantization/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (597 Bytes).
 
packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_cast.cpython-312.pyc ADDED
Binary file (7.76 kB).
 
packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_scaled_mm.cpython-312.pyc ADDED
Binary file (10.3 kB).
 
packages/ltx-core/src/ltx_core/quantization/__pycache__/policy.cpython-312.pyc ADDED
Binary file (2.02 kB).