Manmay committed (verified)
Commit 08c5e28 · 1 Parent(s): 5d085de

DramaBox Space — initial app + vendored ltx2

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. LICENSE +381 -0
  2. README.md +151 -31
  3. app.py +176 -0
  4. assets/silence_latent_frame.pt +3 -0
  5. configs/training_args.example.yaml +53 -0
  6. ltx2/ltx_core/__init__.py +0 -0
  7. ltx2/ltx_core/batch_split.py +95 -0
  8. ltx2/ltx_core/components/__init__.py +10 -0
  9. ltx2/ltx_core/components/diffusion_steps.py +106 -0
  10. ltx2/ltx_core/components/guiders.py +383 -0
  11. ltx2/ltx_core/components/noisers.py +35 -0
  12. ltx2/ltx_core/components/patchifiers.py +348 -0
  13. ltx2/ltx_core/components/protocols.py +101 -0
  14. ltx2/ltx_core/components/schedulers.py +130 -0
  15. ltx2/ltx_core/conditioning/__init__.py +19 -0
  16. ltx2/ltx_core/conditioning/exceptions.py +4 -0
  17. ltx2/ltx_core/conditioning/item.py +20 -0
  18. ltx2/ltx_core/conditioning/mask_utils.py +210 -0
  19. ltx2/ltx_core/conditioning/types/__init__.py +13 -0
  20. ltx2/ltx_core/conditioning/types/attention_strength_wrapper.py +71 -0
  21. ltx2/ltx_core/conditioning/types/keyframe_cond.py +70 -0
  22. ltx2/ltx_core/conditioning/types/latent_cond.py +44 -0
  23. ltx2/ltx_core/conditioning/types/noise_mask_cond.py +45 -0
  24. ltx2/ltx_core/conditioning/types/reference_video_cond.py +91 -0
  25. ltx2/ltx_core/guidance/__init__.py +15 -0
  26. ltx2/ltx_core/guidance/perturbations.py +79 -0
  27. ltx2/ltx_core/layer_streaming.py +324 -0
  28. ltx2/ltx_core/loader/__init__.py +48 -0
  29. ltx2/ltx_core/loader/fuse_loras.py +133 -0
  30. ltx2/ltx_core/loader/kernels.py +72 -0
  31. ltx2/ltx_core/loader/module_ops.py +14 -0
  32. ltx2/ltx_core/loader/primitives.py +146 -0
  33. ltx2/ltx_core/loader/registry.py +84 -0
  34. ltx2/ltx_core/loader/sd_ops.py +139 -0
  35. ltx2/ltx_core/loader/sft_loader.py +66 -0
  36. ltx2/ltx_core/loader/single_gpu_model_builder.py +136 -0
  37. ltx2/ltx_core/modality_tiling.py +222 -0
  38. ltx2/ltx_core/model/__init__.py +8 -0
  39. ltx2/ltx_core/model/audio_vae/__init__.py +29 -0
  40. ltx2/ltx_core/model/audio_vae/attention.py +71 -0
  41. ltx2/ltx_core/model/audio_vae/audio_vae.py +508 -0
  42. ltx2/ltx_core/model/audio_vae/causal_conv_2d.py +110 -0
  43. ltx2/ltx_core/model/audio_vae/causality_axis.py +10 -0
  44. ltx2/ltx_core/model/audio_vae/downsample.py +110 -0
  45. ltx2/ltx_core/model/audio_vae/model_configurator.py +200 -0
  46. ltx2/ltx_core/model/audio_vae/ops.py +73 -0
  47. ltx2/ltx_core/model/audio_vae/resnet.py +176 -0
  48. ltx2/ltx_core/model/audio_vae/upsample.py +106 -0
  49. ltx2/ltx_core/model/audio_vae/vocoder.py +594 -0
  50. ltx2/ltx_core/model/common/__init__.py +9 -0
LICENSE ADDED
@@ -0,0 +1,381 @@
+ LTX-2 Community License Agreement
+ License date: January 5, 2026
+
+
+ By using or distributing any portion or element of LTX-2, you agree
+ to be bound by this Agreement.
+
+ 1. Definitions.
+
+ "Agreement" means the terms and conditions for the license, use,
+ reproduction, and distribution of LTX-2 and the Complementary
+ Materials, as specified in this document.
+
+ "Control" means the direct or indirect ownership of more than
+ fifty percent (50%) of the voting securities or other ownership
+ interests, or the power to direct the management and policies of
+ such Entity through voting rights, contract, or otherwise.
+
+ "Data" means a collection of information and/or content extracted
+ from the dataset used with LTX-2, including to train, pretrain,
+ or otherwise evaluate LTX-2. The Data is not licensed under this
+ Agreement.
+
+ "Derivatives of LTX-2" means all modifications to LTX-2, works
+ based on LTX-2, or any other model which is created or initialized
+ by transfer of patterns of the weights, parameters, activations or
+ output of LTX-2, to the other model, in order to cause the other
+ model to perform similarly to LTX-2, including – but not limited
+ to - distillation methods entailing the use of intermediate data
+ representations or methods based on the generation of synthetic
+ data by LTX-2 for training the other model. For clarity, Derivatives
+ of LTX-2 include: (i) any fine-tuned or adapted weights, parameters,
+ or checkpoints derived from LTX-2; (ii) derivative model architectures
+ that incorporate or are based upon LTX-2's architecture; and
+ (iii) any modified or extended versions of the Complementary
+ Materials. All intellectual property rights in Derivatives of LTX-2
+ shall be subject to the terms of this Agreement, and you may not
+ claim exclusive ownership rights in any Derivatives of LTX-2 that
+ would restrict the rights granted herein.
+
+ "Entity" means any individual, corporation, partnership, limited
+ liability company, or other legal entity. For purposes of this
+ Agreement, an Entity shall be deemed to include, on an aggregative
+ basis, all subsidiaries, affiliates, and other companies under
+ common Control with such Entity. When determining whether an Entity
+ meets any threshold under this Agreement (including revenue
+ thresholds), all subsidiaries, affiliates, and companies under
+ common Control shall be considered collectively.
+
+ "Harm" includes but is not limited to physical, mental,
+ psychological, financial and reputational damage, pain, or loss.
+
+ "Licensor" or "Lightricks" means the owner that is granting the
+ license under this Agreement. For the purposes of this Agreement,
+ the Licensor is Lightricks Ltd.
+
+ "LTX-2" means the large language models, text/image/video/audio/3D
+ generation models, and multimodal large language models and their
+ software and algorithms, including trained model weights, parameters
+ (including optimizer states), machine-learning model code,
+ inference-enabling code, training-enabling code, fine-tuning
+ enabling code, accompanying source code, scripts, documentation,
+ tutorials, examples, and all other elements of the foregoing
+ distributed and made publicly available by Lightricks (including,
+ for example, at https://github.com/Lightricks/LTX-2) for the LTX-2
+ model released on January 5, 2026. This license is applicable to
+ all LTX-2 versions released since January 5, 2026, and all future
+ releases of LTX-2 under this license.
+
+ "Output" means the results of operating LTX-2 as embodied in
+ informational content resulting therefrom.
+
+ "you" (or "your") means an individual or legal Entity licensing
+ LTX-2 in accordance with this Agreement and/or making use of LTX-2
+ for whichever purpose and in any field of use, including usage of
+ LTX-2 in an end-use application - e.g. chatbot, translator, image
+ generator.
+
+ 2. Grant of License. Subject to the terms and conditions of this
+ Agreement, you are granted a non-exclusive, worldwide,
+ non-transferable and royalty-free limited license under Licensor's
+ intellectual property or other rights owned by Licensor embodied
+ in LTX-2 to use, reproduce, prepare, distribute, publicly display,
+ publicly perform, sublicense, copy, create derivative works of,
+ and make modifications to LTX-2, for any purpose, subject to the
+ restrictions set forth in Attachment A; provided however, that
+ Entities with annual revenues of at least $10,000,000 (the
+ "Commercial Entities") are required to obtain a paid commercial
+ use license in order to use LTX-2 and Derivatives of LTX-2,
+ subject to the terms and provisions of a different license (the
+ "Commercial Use Agreement"), as will be provided by the Licensor.
+ Commercial Entities interested in such a commercial license are
+ required to [contact Licensor](https://ltx.io/model/licensing).
+ Any commercial use of LTX-2 or Derivatives of LTX-2 by the
+ Commercial Entities not in accordance with this Agreement and/or
+ the Commercial Use Agreement is strictly prohibited and shall be
+ deemed a material breach of this Agreement. Such material breach
+ will be subject, in addition to any license fees owed to Licensor
+ for the period such Commercial Entity used LTX-2 (as will be
+ determined by Licensor), to liquidated damages, which will be paid
+ to Licensor immediately upon demand, in an amount equal to double
+ the amount that would otherwise have been paid by you for the
+ relevant period of time. Such amount reflects a reasonable estimation
+ of the losses and administrative costs incurred due to such breach.
+ You agree and understand that this remedy does not limit the Licensor's
+ right to pursue other remedies available at law or equity.
+
+ 3. Distribution and Redistribution. You may host for third parties
+ remote access purposes (e.g. software-as-a-service), reproduce
+ and distribute copies of LTX-2 or Derivatives of LTX-2 thereof in
+ any medium, with or without modifications, provided that you meet
+ the following conditions:
+
+ (a) Use-based restrictions as referenced in paragraph 4 and all
+ provisions of Attachment A MUST be included as an enforceable
+ provision by you in any type of legal agreement (e.g. a
+ license) governing the use and/or distribution of LTX-2 or
+ Derivatives of LTX-2, and you shall give notice to subsequent
+ users you distribute to, that LTX-2 or Derivatives of LTX-2
+ are subject to paragraph 4 and Attachment A in their entirety,
+ including all use restrictions and acceptable use policies;
+
+ (b) You must provide any third party recipients of LTX-2 or
+ Derivatives of LTX-2 a copy of this Agreement, including all
+ attachments and use policies. Any Derivative of LTX-2 (as
+ defined in Section 1, including but not limited to fine-tuned
+ weights, modified training code, models trained on Outputs, or
+ any other derivative) must be distributed exclusively under
+ the terms of this Agreement with a complete copy of this
+ license included;
+
+ (c) You must cause any modified files to carry prominent notices
+ stating that you changed the files;
+
+ (d) You must retain all copyright, patent, trademark, and
+ attribution notices excluding those notices that do not
+ pertain to any part of LTX-2, Derivatives of LTX-2.
+
+ You may add your own copyright statement to your modifications and
+ may provide additional or different license terms and conditions -
+ respecting paragraph 3(a) - for use, reproduction, or distribution
+ of your modifications, or for any such Derivatives of LTX-2 as a
+ whole, provided your use, reproduction, and distribution of LTX-2
+ otherwise complies with the conditions stated in this Agreement,
+ and you provide a complete copy of this Agreement with any such
+ use, reproduction and distribution of LTX-2 and any Derivatives
+ thereof.
+
+ 4. Use-based restrictions. The restrictions set forth in Attachment A
+ are considered Use-based restrictions. Therefore, you cannot use
+ LTX-2 and the Derivatives of LTX-2 in violation of the specified
+ restricted uses. You may use LTX-2 subject to this Agreement,
+ including only for lawful purposes and in accordance with the
+ Agreement. "Use" may include creating any content with, fine-tuning,
+ updating, running, training, evaluating and/or re-parametrizing
+ LTX-2. You shall require all of your users who use LTX-2 or a
+ Derivative of LTX-2 to comply with the terms of this paragraph 4.
+
+ 5. The Output You Generate. Except as set forth herein, Licensor
+ claims no rights in the Output you generate using LTX-2. You are
+ accountable for input you insert into LTX-2, the Output you
+ generate and its subsequent uses. No use of the Output can
+ contravene any provision as stated in the Agreement.
+
+ 6. Updates and Runtime Restrictions. To the maximum extent permitted
+ by law, Licensor reserves the right to restrict (remotely or
+ otherwise) usage of LTX-2 in violation of this Agreement, update
+ LTX-2 through electronic means, or modify the Output of LTX-2
+ based on updates. You shall undertake reasonable efforts to use
+ the latest version of LTX-2. Any use of the non-current version
+ of LTX-2 is done solely at your risk.
+
+ 7. Export Controls and Sanctions Compliance. You acknowledge that
+ LTX-2, Derivatives of LTX-2 may be subject to export control laws
+ and regulations, including but not limited to the U.S. Export
+ Administration Regulations and sanctions programs administered by
+ the Office of Foreign Assets Control (OFAC). You represent and
+ warrant that you and any users of LTX-2 are not (i) located in,
+ organized under the laws of, or ordinarily resident in any country
+ or territory subject to comprehensive sanctions; (ii) identified
+ on any U.S. government restricted party list, including the
+ Specially Designated Nationals and Blocked Persons List; or
+ (iii) otherwise prohibited from receiving LTX-2 under applicable
+ law. You shall not export, re-export, or transfer LTX-2, directly
+ or indirectly, in violation of any applicable export control or
+ sanctions laws or regulations. You agree to comply with all
+ applicable trade control laws and shall indemnify and hold
+ Licensor harmless from any claims arising from your failure to
+ comply with such laws.
+
+ 8. Trademarks and related. Nothing in this Agreement permits you to
+ make use of Licensor's trademarks, trade names, logos or to
+ otherwise suggest endorsement or misrepresent the relationship
+ between the parties; and any rights not expressly granted herein
+ are reserved by the Licensor.
+
+ 9. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides LTX-2 on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or
+ conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS
+ FOR A PARTICULAR PURPOSE. You are solely responsible for
+ determining the appropriateness of using or redistributing LTX-2
+ and Derivatives of LTX-2 and assume any risks associated with
+ your exercise of permissions under this Agreement.
+
+ 10. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall Licensor be liable
+ to you for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as
+ a result of this Agreement or out of the use or inability to use
+ LTX-2 (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if Licensor has been
+ advised of the possibility of such damages.
+
+ 11. Accepting Warranty or Additional Liability. While redistributing
+ LTX-2 and Derivatives of LTX-2, you may, provided you do not
+ violate the terms of this Agreement, choose to offer and charge
+ a fee for, acceptance of support, warranty, indemnity, or other
+ liability obligations. However, in accepting such obligations,
+ you may act only on your own behalf and on your sole
+ responsibility, not on behalf of Licensor, and only if you agree
+ to indemnify, defend, and hold Licensor harmless for any liability
+ incurred by, or claims asserted against Licensor, by reason of
+ your accepting any such warranty or additional liability.
+
+ 12. Governing Law. This Agreement and all relations, disputes, claims
+ and other matters arising hereunder (including non-contractual
+ disputes or claims) will be governed exclusively by, and construed
+ exclusively in accordance with, the laws of the State of New York.
+ To the extent permitted by law, choice of laws rules and the
+ United Nations Convention on Contracts for the International Sale
+ of Goods will not apply. For the purposes of adjudicating any
+ action or proceeding to enforce the terms of this Agreement, you
+ hereby irrevocably consent to the exclusive jurisdiction of, and
+ venue in, the federal and state courts located in the County of
+ New York within the State of New York. The prevailing party in
+ any claim or dispute between the parties under this Agreement
+ will be entitled to reimbursement of its reasonable attorneys'
+ fees and costs. You hereby waive the right to a trial by jury,
+ to participate in a class or representative action (including in
+ arbitration), or to combine individual proceedings in court or
+ in arbitration without the consent of all parties.
+
+ 13. Term and Termination. This Agreement is effective upon your
+ acceptance and continues until terminated. Licensor may terminate
+ this Agreement immediately upon written notice to you if you
+ breach any provision of this Agreement, including but not limited
+ to violations of the use restrictions in Attachment A or
+ unauthorized commercial use. Upon termination: (a) all rights
+ granted to you under this Agreement will immediately cease;
+ (b) you must immediately cease all use of LTX-2 and Derivatives
+ of LTX-2; (c) you must delete or destroy all copies of LTX-2
+ and Derivatives of LTX-2 in your possession or control; and
+ (d) you must notify any third parties to whom you distributed
+ LTX-2 or Derivatives of LTX-2 of the termination. Sections 8-13,
+ and Section 15 shall survive termination of this Agreement.
+ Termination does not relieve you of any obligations incurred
+ prior to termination, including payment obligations under
+ Section 2. In addition, if You commence a lawsuit or other
+ proceedings (including a cross-claim or counterclaim in a lawsuit)
+ against Licensor or any person or entity alleging that LTX-2 or
+ any Output, or any portion of any of the foregoing, infringe any
+ intellectual property or other right owned or licensable by you,
+ then all licenses granted to you under this Agreement shall
+ terminate as of the date such lawsuit or other proceeding is filed.
+
+ 14. Disputes and Arbitration. All disputes arising in connection with
+ this Agreement shall be finally settled by arbitration under the
+ Rules of Arbitration of the International Chamber of Commerce
+ ("ICC Rules"), by one (1) arbitrator appointed in accordance with
+ the ICC Rules. The seat of arbitration shall be New York, NY, USA,
+ and the proceedings shall be conducted in English. The arbitrator
+ shall be empowered to grant any relief that a court could grant.
+ Judgment on the arbitration award may be entered by any court
+ having jurisdiction thereof. Each party waives its right to a
+ trial by jury and to participate in any class or representative
+ action.
+
+ 15. If any provision of this Agreement is held to be invalid, illegal
+ or unenforceable, the remaining provisions shall be unaffected
+ thereby and remain valid as if such provision had not been set
+ forth herein.
+
+ END OF TERMS AND CONDITIONS
+
+ ATTACHMENT A: Use Restrictions
+
+ When using the Outputs, LTX-2 and any Derivatives thereof, you
+ will comply with the Acceptable Use Policy. In addition, you
+ agree not to use the Outputs, LTX-2 or its Derivatives in any
+ of the following ways:
+
+ 1. In any way that violates any applicable national, federal,
+ state, local or international law or regulation;
+
+ 2. For the purpose of exploiting, Harming or attempting to
+ exploit or Harm minors in any way;
+
+ 3. To generate or disseminate false information and/or content
+ with the purpose of Harming others;
+
+ 4. To generate or disseminate personal identifiable information
+ that can be used to Harm an individual;
+
+ 5. To generate or disseminate information and/or content (e.g.
+ images, code, posts, articles), and place the information
+ and/or content in any context (e.g. bot generating tweets)
+ without expressly and intelligibly disclaiming that the
+ information and/or content is machine generated;
+
+ 6. To defame, disparage or otherwise harass others;
+
+ 7. To impersonate or attempt to impersonate (e.g. deepfakes)
+ others without their consent;
+
+ 8. For fully automated decision making that adversely impacts an
+ individual's legal rights or otherwise creates or modifies a
+ binding, enforceable obligation;
+
+ 9. For any use intended to or which has the effect of
+ discriminating against or Harming individuals or groups based
+ on online or offline social behavior or known or predicted
+ personal or personality characteristics;
+
+ 10. To exploit any of the vulnerabilities of a specific group of
+ persons based on their age, social, physical or mental
+ characteristics, in order to materially distort the behavior
+ of a person pertaining to that group in a manner that causes
+ or is likely to cause that person or another person physical
+ or psychological Harm;
+
+ 11. For any use intended to or which has the effect of
+ discriminating against individuals or groups based on legally
+ protected characteristics or categories;
+
+ 12. To provide medical advice and medical results interpretation;
+
+ 13. To generate or disseminate information for the purpose to be
+ used for administration of justice, law enforcement,
+ immigration or asylum processes, such as predicting an
+ individual will commit fraud/crime commitment (e.g. by text
+ profiling, drawing causal relationships between assertions
+ made in documents, indiscriminate and arbitrarily-targeted use);
+
+ 14. To generate and/or disseminate malware (including – but not
+ limited to – ransomware) or any other content to be used for
+ the purpose of harming electronic systems;
+
+ 15. To engage in, promote, incite, or facilitate discrimination
+ or other unlawful or harmful conduct in the provision of
+ employment, employment benefits, credit, housing, or other
+ essential goods and services;
+
+ 16. To engage in, promote, incite, or facilitate the harassment,
+ abuse, threatening, or bullying of individuals or groups of
+ individuals;
+
+ 17. For military, warfare, nuclear industries or applications,
+ weapons development, or any use in connection with activities
+ that may cause death, personal injury, or severe physical or
+ environmental damage;
+
+ 18. For commercial use only: To train, improve, or fine-tune any
+ other machine learning model, artificial intelligence system,
+ or competing model, except for Derivatives of LTX-2 as
+ expressly permitted under this Agreement;
+
+ 19. To circumvent, disable, or interfere with any technical
+ limitations, safety features, content filters, or use
+ restrictions implemented in LTX-2 by Licensor;
+
+ 20. To use LTX-2 or Derivatives of LTX-2 in any product, service,
+ or application that directly competes with Licensor's
+ commercial products or services, or is designed to replace or
+ substitute Licensor's offerings in the market, without
+ obtaining a separate commercial license from Licensor.
README.md CHANGED
@@ -1,42 +1,162 @@
- ---
- title: DramaBox
- emoji: 🎭
- colorFrom: red
- colorTo: indigo
- sdk: gradio
- sdk_version: 4.44.1
- app_file: app.py
- pinned: true
- license: other
- license_name: ltx-2-community
- license_link: https://huggingface.co/ResembleAI/Dramabox/blob/main/LICENSE
- hardware: l40s
- short_description: Expressive TTS with voice cloning — DramaBox demo
- ---
-
- # DramaBox Expressive TTS Demo
-
- Live demo of [`ResembleAI/Dramabox`](https://huggingface.co/ResembleAI/Dramabox). Write a scene prompt, optionally upload a 10-second voice reference, and generate. Audio is automatically watermarked with [Resemble Perth](https://github.com/resemble-ai/Perth).
-
- The model checkpoints download automatically on first launch.
-
- ## Prompt format
-
- ```
- <speaker description>, "<dialogue>" <action direction> "<more dialogue>"
- ```
-
- - **Inside double quotes**: dialogue and phonetic sounds (`"Hahaha"`, `"Mmmmm"`, `"Ugh"`)
- - **Outside quotes**: stage directions (`She sighs.`, `He clears his throat.`)
- - **Avoid inside quotes**: `Ahem`, `Pfft`, `Sigh`, `Gasp`, `Cough` — the model will speak them literally.
-
- See the **Load an example prompt** dropdown for ready-made scene templates.
-
- ## Files
-
- `app.py` — Gradio UI
- `src/inference_server.py` — warm `TTSServer` (single load, ~2.5s/request)
- `src/inference.py` — CLI inference
- `src/model_downloader.py` — auto-fetches model from HuggingFace
- `ltx2/` — vendored LTX-2 pipelines
- `requirements.txt` — Python deps (includes `resemble-perth`)
+ # Dramabox - Expressive TTS with Voice Cloning
+
+ Prompt-driven TTS with voice cloning built on a 3.3B Diffusion Transformer with flow matching.
+
+ ## Folder Structure
+
+ ```
+ DramaBox/
+ ├── src/
+ │   ├── inference.py           # TTS inference with voice cloning
+ │   ├── inference_server.py    # Warm server (~2.5s per generation)
+ │   ├── audio_conditioning.py  # Reference audio conditioning
+ │   └── model_downloader.py    # Auto-download models from HuggingFace
+ ├── patches/
+ │   ├── attention.py           # dtype fix for mask allocation
+ │   └── guiders.py             # Per-token CFG clamping
+ ├── assets/
+ │   └── silence_latent_frame.pt
+ ├── evals/
+ │   ├── eval_short.txt         # 30 short prompts (~5-15s)
+ │   ├── eval_long.txt          # 15 long prompts (~20-37s)
+ │   └── eval_expressive.txt    # 15 expressive prompts (laughs, sighs, stammers)
+ ├── scripts/
+ │   ├── inference.sh           # Inference wrapper
+ │   └── eval.sh                # Evaluation runner
+ ├── app.py                     # Gradio demo app
+ ├── ltx2/                      # LTX-2 dependency packages
+ └── README.md
+ ```
+
+ ## Models
+
+ Models auto-download from [ResembleAI/Dramabox](https://huggingface.co/ResembleAI/Dramabox) on HuggingFace.
+
+ | Model | Size | Description |
+ |-------|------|-------------|
+ | `dramabox-dit-v1.safetensors` | 6.6 GB | DiT transformer |
+ | `dramabox-audio-components.safetensors` | 2.7 GB | Audio VAE + vocoder + text projection |
+ | [unsloth/gemma-3-12b-it-bnb-4bit](https://huggingface.co/unsloth/gemma-3-12b-it-bnb-4bit) | ~8 GB | Text encoder (auto-downloaded) |
+
+ **VRAM**: ~24 GB peak | **Speed**: ~2.5s per generation (warm server, H100)
+
+ ## Quick Start
+
+ ### Warm Server (recommended, ~2.5s per request)
+
+ ```python
+ from src.inference_server import TTSServer
+
+ server = TTSServer(device="cuda")
+
+ server.generate_to_file(
+     prompt='A woman speaks warmly, "Hello, how are you today?" She laughs, "Hahaha, it is so good to see you!"',
+     output="output.wav",
+     voice_ref="reference.wav",  # optional, 10+ seconds
+ )
+ ```
+
+ ### Gradio App
+
+ ```bash
+ GEMINI_API_KEY=your_key CUDA_VISIBLE_DEVICES=4 python app.py
+ ```
+
+ ### CLI Inference
+
+ ```bash
+ python src/inference.py \
+     --voice-sample reference.wav \
+     --prompt 'A woman speaks warmly, "Hello, how are you today?"' \
+     --output output.wav \
+     --cfg-scale 2.5 --stg-scale 1.5
+ ```
+
+ ### Evaluation
+
+ ```bash
+ bash scripts/eval.sh --eval expressive --output eval_results/
+ ```
+
+ ## Inference Settings
+
+ | Parameter | Default | Notes |
+ |-----------|---------|-------|
+ | cfg-scale | 2.5 | Lower = more natural, higher = more text following |
+ | stg-scale | 1.5 | Skip-token guidance |
+ | rescale | 0 | No rescaling |
+ | modality | 1 | No modality guidance |
+ | duration-multiplier | 1.1 | 10% breathing room |
+ | steps | 30 | Euler flow matching |
+
+ ## Prompt Writing Guide
+
+ **Structure:** `<speaker description>, "<dialogue>" <action direction> "<more dialogue>"`
+
+ ### What works inside quotes (model produces actual sounds)
+ - Laughs: `"Hahaha"` `"Hehehe"` (always one word, never separated)
+ - Sounds: `"Mmmmm"` `"Ugh"` `"Argh"` `"Ahhh"` `"Hmm"`
+
+ ### What goes outside quotes (stage directions)
+ - `She sighs deeply.` `He gulps nervously.` `A long pause.`
+ - `Her voice cracks.` `He clears his throat.` `She scoffs.`
+
+ ### Never inside quotes (model speaks them literally)
+ - Ahem, Pfft, Sigh, Gasp, Cough
+
+ ### Tips
+ - Match gender/age in speaker description to voice reference
+ - Break long dialogue into segments with acting directions between them
+ - End prompt at the last closing quote mark (no trailing descriptions)
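The rules above can be sketched as a small helper. `build_prompt` is a hypothetical name for illustration, not part of the DramaBox API; it keeps dialogue inside double quotes, stage directions outside, and ends the prompt at the last closing quote:

```python
def build_prompt(speaker: str, parts: list[tuple[str, str]]) -> str:
    """parts is a list of (dialogue, stage_direction) pairs.
    Dialogue goes inside double quotes; stage directions stay outside."""
    chunks = [f"{speaker},"]
    for dialogue, direction in parts:
        chunks.append(f'"{dialogue}"')
        if direction:  # empty direction lets the prompt end on a quote
            chunks.append(direction)
    return " ".join(chunks)

prompt = build_prompt(
    "A woman speaks warmly",
    [("Hello, how are you today?", "She laughs."),
     ("Hahaha, it is so good to see you!", "")],  # laugh spelled out inside quotes
)
print(prompt)
```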
+
+ ## Watermarking
+
+ Every audio output from `inference.py` and `inference_server.TTSServer.generate_to_file` is automatically watermarked with [Resemble Perth](https://github.com/resemble-ai/Perth) — an imperceptible neural watermark that survives MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.
+
+ ```python
+ import perth, librosa
+ wav, sr = librosa.load("output.wav", sr=None, mono=True)
+ detector = perth.PerthImplicitWatermarker()
+ print(detector.get_watermark(wav, sample_rate=sr))  # confidence ≈ 1.0 for our outputs
+ ```
+
+ Pass `--no-watermark` to `inference.py` (or `watermark=False` to `generate_to_file`) to disable for debugging.
+
+ ## Training
+
+ DramaBox is an IC-LoRA fine-tune of the LTX-2.3 22B audio-only branch. To train your own:
+
+ ```bash
+ # 1. Preprocess raw (audio, transcript) pairs → audio_latents/ + conditions/
+ python src/preprocess.py \
+     --dataset-type manifest \
+     --index your_data.jsonl \
+     --output-dir /path/to/preprocessed/ \
+     --checkpoint dramabox-audio-components.safetensors \
+     --gemma-root /path/to/gemma-3-12b-it-bnb-4bit/
+
+ # 2. Edit configs/training_args.example.yaml → your data paths
+
+ # 3. Launch (uses HuggingFace accelerate)
+ bash scripts/train.sh \
+     --config configs/training_args.example.yaml \
+     --gpus 0,1,2,3,4,5,6 \
+     --train-val-gpu 7
+ ```
146
 
147
+ | Script | Purpose |
148
+ |---|---|
149
+ | `src/preprocess.py` | Encode audio (Audio VAE) + text (Gemma) into training-ready `.pt` files |
150
+ | `src/train.py` | IC-LoRA training loop with peft, accelerate multi-GPU, periodic validation |
151
+ | `src/validate.py` | Spawned by `train.py` at each save step; runs the warm validator on a held-out prompt set |
152
+ | `scripts/train.sh` | YAML-config wrapper around `accelerate launch src/train.py` |
153
+
154
+ LoRA targets the audio branch only: `audio_attn1.{to_q,to_k,to_v,to_out.0}` + `audio_ff.{net.0.proj,net.2}` × 48 transformer blocks (288 LoRA pairs total). Default rank 128 / alpha 128 / dropout 0.1, cosine LR schedule from 1e-4 with 500-step warmup over 10k steps.
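To make that count concrete, here is a minimal sketch enumerating the target modules in pure Python. The `transformer_blocks.{i}.` prefix is a hypothetical naming assumption for illustration, not taken from the training code:

```python
# 4 attention projections + 2 feed-forward projections per block, 48 blocks.
ATTN = ["to_q", "to_k", "to_v", "to_out.0"]
FF = ["net.0.proj", "net.2"]

targets = [
    f"transformer_blocks.{i}.{branch}.{leaf}"
    for i in range(48)
    for branch, leaves in (("audio_attn1", ATTN), ("audio_ff", FF))
    for leaf in leaves
]
print(len(targets))  # 288
```

A list like this is what you would hand to something like `peft.LoraConfig(target_modules=...)` together with the rank-128 / alpha-128 defaults above.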
155
+
156
+ ## Language
157
 
158
+ English.
159
 
160
+ ## License
161
 
162
+ Built on [LTX-2](https://github.com/Lightricks/LTX-2) by Lightricks. Distributed under the LTX-2 Community License Agreement; see [`LICENSE`](LICENSE).
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,176 @@
1
+ #!/usr/bin/env python3
2
+ """DramaBox — Gradio demo (warm server).
3
+
4
+ Loads the warm TTSServer once, then handles requests at ~2.5 s each. All
5
+ generated audio is invisibly watermarked with Resemble Perth before being
6
+ returned to the user.
7
+ """
8
+ import logging
9
+ import os
10
+ import sys
11
+ import tempfile
12
+ import time
13
+
14
+ import gradio as gr
15
+
16
+ # Local src import.
17
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
18
+ from inference_server import TTSServer # noqa: E402
19
+
20
+
21
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
22
+ logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
23
+ tts = TTSServer(
24
+ device="cuda",
25
+ dtype=os.environ.get("LTX_DTYPE", "bf16"),
26
+ compile_model=os.environ.get("LTX_COMPILE", "0") == "1",
27
+ bnb_4bit=True, # default Gemma is unsloth pre-quantized
28
+ )
29
+ logging.info("Server ready.")
30
+
31
+
32
+ # ── Example prompts (shown as click-to-fill chips in the UI) ─────────────────
33
+ EXAMPLES: list[tuple[str, str]] = [
34
+ (
35
+ "Villain monologue",
36
+ 'A shadowy villain speaks with cold menace, "You have entered my domain, mortal." '
37
+ 'He chuckles darkly, "Such arrogance will be your undoing." '
38
+ 'His voice rises with fury, "Kneel, or be destroyed where you stand!"'
39
+ ),
40
+ (
41
+ "Talk-show host wheeze-laugh",
42
+ 'A talk show host gasps with shock, "No! You did NOT just say that!" '
43
+ 'He bursts into uncontrollable laughter, "Hahaha! Oh my god, oh my god!" '
44
+ 'He wheezes, "I cannot, I literally cannot breathe right now!"'
45
+ ),
46
+ (
47
+ "Tender goodnight whisper",
48
+ 'A woman speaks tenderly, "It has been a long day, my love." '
49
+ 'She whispers, "Close your eyes. I am right here." '
50
+ 'She hums quietly, "Mmmm-mmm. Sleep now."'
51
+ ),
52
+ (
53
+ "Old-school radio anchor",
54
+ 'A radio host clears his throat, "Excuse me, pardon that." '
55
+ 'He settles into a warm, professional tone, "Good evening everyone, '
56
+ 'and welcome back to the show. We have got a wonderful lineup tonight."'
57
+ ),
58
+ (
59
+ "Catgirl uncontrollable giggling",
60
+ 'A playful girl already mid-giggle, "Hehehe, oh my gosh you should see your face!" '
61
+ 'She gasps for air between giggles, "Oh my, hehe, oh my, I cannot stop!" '
62
+ 'She tries to compose herself, "Ahhhhh okay okay okay, I will stop, I promise."'
63
+ ),
64
+ (
65
+ "Hero stammering courage",
66
+ 'A young warrior speaks with a trembling voice, "I... I do not know if I can do this." '
67
+ 'He takes a shaky breath, "But someone has to try." '
68
+ 'His voice steadies with growing fire, "No more running. I WILL fight!"'
69
+ ),
70
+ (
71
+ "Exhausted dad, fraying patience",
72
+ 'An exhausted father speaks with fraying patience, "Sweetie, daddy is asking very nicely." '
73
+ 'He sighs deeply, "Ohhhh my goodness." '
74
+ 'He puts on an overly cheerful voice, "Hey buddy! Look at the shiny thing!" '
75
+ 'Then he laughs helplessly, "Hahaha, I am losing my mind."'
76
+ ),
77
+ (
78
+ "Smug-confident announcer",
79
+ 'A confident announcer speaks proudly, "And now, the moment you have all been waiting for." '
80
+ 'He chuckles knowingly, "Heheh, trust me, this one is going to blow you away."'
81
+ ),
82
+ ]
83
+
84
+
85
+ def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float, seed: int):
86
+ if not prompt or not prompt.strip():
87
+ raise gr.Error("Prompt is empty.")
88
+ t0 = time.time()
89
+ ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None
90
+ fd, output = tempfile.mkstemp(suffix=".wav", prefix="dramabox_")
+ os.close(fd)  # mkstemp avoids the insecure, deprecated tempfile.mktemp
91
+ tts.generate_to_file(
92
+ prompt=prompt,
93
+ output=output,
94
+ voice_ref=ref_path,
95
+ cfg_scale=cfg, stg_scale=stg,
96
+ duration_multiplier=dur_mult, seed=int(seed),
97
+ )
98
+ elapsed = time.time() - t0
99
+ logging.info(f"Generated in {elapsed:.2f}s -> {output}")
100
+ return output
101
+
102
+
103
+ # ── UI ──────────────────────────────────────────────────────────────────────
104
+ with gr.Blocks(
105
+ title="DramaBox — Expressive TTS",
106
+ theme=gr.themes.Default(),
107
+ css=".prompt-box textarea { font-size: 14px !important; line-height: 1.5 !important; }",
108
+ ) as app:
109
+ gr.Markdown("# 🎭 DramaBox — Expressive TTS with Voice Cloning")
110
+ gr.Markdown(
111
+ "Write a scene prompt, optionally upload a 10-second voice reference, "
112
+ "and generate. Audio is automatically watermarked with "
113
+ "[Resemble Perth](https://github.com/resemble-ai/Perth).\n\n"
114
+ "**Tips:** put dialogue inside `\"double quotes\"`, scene directions outside. "
115
+ "Phonetic sounds (`\"Hahaha\"`, `\"Mmmm\"`, `\"Ugh\"`) go inside quotes; named "
116
+ "actions (`She sighs.`, `He clears his throat.`) go outside."
117
+ )
118
+
119
+ with gr.Row():
120
+ with gr.Column(scale=3):
121
+ prompt_box = gr.Textbox(
122
+ label="Scene prompt",
123
+ placeholder=EXAMPLES[0][1],
124
+ lines=6, elem_classes=["prompt-box"],
125
+ )
126
+ example_chooser = gr.Dropdown(
127
+ choices=[e[0] for e in EXAMPLES],
128
+ label="Load an example prompt", interactive=True, value=None,
129
+ )
130
+ audio_ref = gr.Audio(
131
+ label="Voice reference (optional, 10+ seconds)",
132
+ type="filepath",
133
+ )
134
+ gen_btn = gr.Button("Generate", variant="primary", size="lg")
135
+
136
+ with gr.Column(scale=2):
137
+ with gr.Accordion("Inference settings", open=False):
138
+ cfg_slider = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale")
139
+ stg_slider = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale")
140
+ dur_slider = gr.Slider(0.8, 2.0, value=1.1, step=0.05, label="Duration ×")
141
+ seed_input = gr.Number(value=42, label="Seed", precision=0)
142
+ audio_out = gr.Audio(label="Generated audio", type="filepath")
143
+ with gr.Accordion("Prompt writing guide", open=False):
144
+ gr.Markdown(
145
+ "**Structure:** `<speaker description>, \"<dialogue>\" <action> \"<more dialogue>\"`\n\n"
146
+ "**Inside quotes** (model speaks them):\n"
147
+ "- Dialogue: `\"Hello, how are you?\"`\n"
148
+ "- Phonetic sounds: `\"Hahaha\"`, `\"Hehehe\"`, `\"Mmmmm\"`, `\"Ugh\"`, `\"Argh\"`\n\n"
149
+ "**Outside quotes** (stage directions):\n"
150
+ "- `She sighs deeply.`, `He gulps nervously.`, `A long pause.`\n"
151
+ "- `Her voice cracks.`, `He clears his throat.`\n\n"
152
+ "**Avoid inside quotes:** Ahem, Pfft, Sigh, Gasp, Cough — the model speaks them literally."
153
+ )
154
+
155
+ def _load_example(choice: str):
156
+ if not choice:
157
+ return gr.update()
158
+ for name, prompt in EXAMPLES:
159
+ if name == choice:
160
+ return prompt
161
+ return gr.update()
162
+
163
+ example_chooser.change(_load_example, inputs=[example_chooser], outputs=[prompt_box])
164
+ gen_btn.click(
165
+ on_generate,
166
+ inputs=[prompt_box, audio_ref, cfg_slider, stg_slider, dur_slider, seed_input],
167
+ outputs=[audio_out],
168
+ )
169
+
170
+
171
+ if __name__ == "__main__":
172
+ port = int(os.environ.get("GRADIO_SERVER_PORT", "7861"))
173
+ app.queue(max_size=10).launch(
174
+ server_name="0.0.0.0", server_port=port,
175
+ share=os.environ.get("GRADIO_SHARE", "0") == "1",
176
+ )
assets/silence_latent_frame.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f73746d2163f8f1742c5de89005404ccaeeff05154bbb10a3337bf9bd13f161c
3
+ size 1501
configs/training_args.example.yaml ADDED
@@ -0,0 +1,53 @@
1
+ # Example DramaBox IC-LoRA training config. Used by scripts/train.sh.
2
+
3
+ # Where to load preprocessed `audio_latents/` + `conditions/` shards from.
4
+ data_dir:
5
+ - /path/to/preprocessed_dataset_a/
6
+ - /path/to/preprocessed_dataset_b/
7
+
8
+ # One index file per data_dir entry. Each line:
9
+ # <sample_id>~<speaker_id>~<lang>~<sample_rate>~<offset>~<duration>~<phonemes>~<text>
10
+ speaker_index:
11
+ - /path/to/preprocessed_dataset_a/index.txt
12
+ - /path/to/preprocessed_dataset_b/index.txt
13
+
14
+ # Output directory (relative is fine — resolved against the repo root).
15
+ output_dir: tts_iclora_v1
16
+
17
+ # LTX-2.3 22B base. Same file is used for the transformer + the aux stack
18
+ # (PromptEncoder, AudioVAE, AudioDecoder).
19
+ checkpoint: ltx-2.3-22b-dev.safetensors
20
+ full_checkpoint: ltx-2.3-22b-dev.safetensors
21
+ base_model: dev
22
+
23
+ # LoRA hyperparams. rank == alpha is the simplest setup (scale = 1.0).
24
+ lora_rank: 128
25
+ lora_alpha: 128
26
+ lora_dropout: 0.1
27
+
28
+ # Voice-cloning ref-token settings.
29
+ ref_ratio: 0.3 # fraction of training samples that get a ref token
30
+ max_ref_tokens: 200 # max ref-token positions appended to target
31
+
32
+ text_dropout: 0.4 # CFG training: drop the text prompt with prob 0.4
33
+
34
+ # Schedule. Use lr_scheduler=constant with a small lr (1e-5) for a "fine-tune"
35
+ # resume; cosine + larger lr (1e-4) for from-scratch.
36
+ steps: 10000
37
+ lr: 1.0e-04
38
+ lr_scheduler: cosine
39
+ warmup_steps: 500
40
+
41
+ batch_size: 1
42
+ grad_accum: 4
43
+ max_grad_norm: 1.0
44
+
45
+ save_every: 500
46
+ log_every: 50
47
+ seed: 53
48
+
49
+ # (Optional) per-checkpoint validation eval — see configs/val_config.example.yaml
50
+ # val_config: val_config.example.yaml
51
+
52
+ # (Optional) resume from a previous LoRA adapter file:
53
+ # resume_lora: tts_iclora_v0/lora_step_05000.safetensors
ltx2/ltx_core/__init__.py ADDED
File without changes
ltx2/ltx_core/batch_split.py ADDED
@@ -0,0 +1,95 @@
1
+ """Batch-splitting adapter for the transformer.
2
+ Wraps an ``X0Model`` (or ``LayerStreamingWrapper``) and splits batched inputs
3
+ into smaller chunks before forwarding, then concatenates the results. This
4
+ controls peak activation memory at the cost of more forward passes.
5
+ The adapter is transparent — it has the same ``forward`` signature as
6
+ ``X0Model`` and proxies attribute access to the wrapped model.
7
+ Example
8
+ -------
9
+ >>> from ltx_core.batch_split import BatchSplitAdapter
10
+ >>> adapter = BatchSplitAdapter(model, max_batch_size=1)
11
+ >>> # Receives B=4, internally runs 4 passes with B=1, returns B=4
12
+ >>> denoised_video, denoised_audio = adapter(video=v_b4, audio=a_b4, perturbations=ptb)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from ltx_core.guidance.perturbations import BatchedPerturbationConfig
23
+ from ltx_core.model.transformer.modality import Modality
24
+
25
+
26
+ def _split_perturbations(config: BatchedPerturbationConfig, sizes: list[int]) -> list[BatchedPerturbationConfig]:
27
+ """Split a ``BatchedPerturbationConfig`` along the batch dimension."""
28
+ it = iter(config.perturbations)
29
+ return [BatchedPerturbationConfig([next(it) for _ in range(s)]) for s in sizes]
30
+
31
+
32
+ def _merge_tensors(tensors: list[torch.Tensor | None]) -> torch.Tensor | None:
33
+ """Concatenate tensors along batch dim, or return None if all are None."""
34
+ non_none = [t for t in tensors if t is not None]
35
+ if not non_none:
36
+ return None
37
+ return torch.cat(non_none, dim=0)
38
+
39
+
40
+ class BatchSplitAdapter(nn.Module):
41
+ """Wraps a model and splits batched forward calls into smaller chunks.
42
+ Has the same ``forward`` signature as ``X0Model``:
43
+ ``(video, audio, perturbations) -> (denoised_video, denoised_audio)``.
44
+ Args:
45
+ model: The model to wrap (``X0Model``, ``LayerStreamingWrapper``, etc.).
46
+ max_batch_size: Maximum batch size per forward pass. Input batches
47
+ larger than this are split into sequential chunks.
48
+ """
49
+
50
+ def __init__(self, model: nn.Module, max_batch_size: int) -> None:
51
+ if max_batch_size < 1:
52
+ raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
53
+ super().__init__()
54
+ self._model = model
55
+ self._max_batch_size = max_batch_size
56
+
57
+ def _get_chunk_sizes(self, batch_size: int) -> list[int]:
58
+ full, remainder = divmod(batch_size, self._max_batch_size)
59
+ sizes = [self._max_batch_size] * full
60
+ if remainder:
61
+ sizes.append(remainder)
62
+ return sizes
63
+
64
+ def forward(
65
+ self,
66
+ video: Modality | None,
67
+ audio: Modality | None,
68
+ perturbations: BatchedPerturbationConfig,
69
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
70
+ batch_size = (video or audio).latent.shape[0]
71
+
72
+ if batch_size <= self._max_batch_size:
73
+ return self._model(video=video, audio=audio, perturbations=perturbations)
74
+
75
+ sizes = self._get_chunk_sizes(batch_size)
76
+ n = len(sizes)
77
+
78
+ v_chunks = video.split(sizes) if video is not None else [None] * n
79
+ a_chunks = audio.split(sizes) if audio is not None else [None] * n
80
+ p_chunks = _split_perturbations(perturbations, sizes)
81
+
82
+ chunk_results = [
83
+ self._model(video=vc, audio=ac, perturbations=pc)
84
+ for vc, ac, pc in zip(v_chunks, a_chunks, p_chunks, strict=True)
85
+ ]
86
+
87
+ results_v, results_a = zip(*chunk_results, strict=True)
88
+ return _merge_tensors(list(results_v)), _merge_tensors(list(results_a))
89
+
90
+ def __getattr__(self, name: str) -> Any: # noqa: ANN401
91
+ """Proxy attribute access to the wrapped model."""
92
+ try:
93
+ return super().__getattr__(name)
94
+ except AttributeError:
95
+ return getattr(self._model, name)
ltx2/ltx_core/components/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Diffusion pipeline components.
3
+ Submodules:
4
+ diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
5
+ guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
6
+ noisers - Noise samplers (GaussianNoiser)
7
+ patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
8
+ protocols - Protocol definitions (Patchifier, etc.)
9
+ schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
10
+ """
ltx2/ltx_core/components/diffusion_steps.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+
3
+ from ltx_core.components.protocols import DiffusionStepProtocol
4
+ from ltx_core.utils import to_velocity
5
+
6
+
7
+ class EulerDiffusionStep(DiffusionStepProtocol):
8
+ """
9
+ First-order Euler method for diffusion sampling.
10
+ Takes a single step from the current noise level (sigma) to the next by
11
+ computing velocity from the denoised prediction and applying: sample + velocity * dt.
12
+ """
13
+
14
+ def step(
15
+ self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **_kwargs
16
+ ) -> torch.Tensor:
17
+ sigma = sigmas[step_index]
18
+ sigma_next = sigmas[step_index + 1]
19
+ dt = sigma_next - sigma
20
+ velocity = to_velocity(sample, sigma, denoised_sample)
21
+
22
+ return (sample.to(torch.float32) + velocity.to(torch.float32) * dt).to(sample.dtype)
23
+
24
+
25
+ class Res2sDiffusionStep(DiffusionStepProtocol):
26
+ """
27
+ Second-order diffusion step for res_2s sampling with SDE noise injection.
28
+ Used by the res_2s denoising loop. Advances the sample from the current
29
+ sigma to the next by mixing a deterministic update (from the denoised
30
+ prediction) with injected noise via ``get_sde_coeff``, producing
31
+ variance-preserving transitions.
32
+ """
33
+
34
+ @staticmethod
35
+ def get_sde_coeff(
36
+ sigma_next: torch.Tensor,
37
+ sigma_up: torch.Tensor | None = None,
38
+ sigma_down: torch.Tensor | None = None,
39
+ sigma_max: torch.Tensor | None = None,
40
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
41
+ """
42
+ Compute SDE coefficients (alpha_ratio, sigma_down, sigma_up) for the step.
43
+ Given either ``sigma_down`` or ``sigma_up``, returns the mixing
44
+ coefficients used for variance-preserving noise injection. If
45
+ ``sigma_up`` is provided, ``sigma_down`` and ``alpha_ratio`` are
46
+ derived; if ``sigma_down`` is provided, ``sigma_up`` and
47
+ ``alpha_ratio`` are derived.
48
+ """
49
+ if sigma_down is not None:
50
+ alpha_ratio = (1 - sigma_next) / (1 - sigma_down)
51
+ sigma_up = (sigma_next**2 - sigma_down**2 * alpha_ratio**2).clamp(min=0) ** 0.5
52
+ elif sigma_up is not None:
53
+ # Fallback to avoid sqrt(neg_num)
54
+ sigma_up.clamp_(max=sigma_next * 0.9999)
55
+ sigmax = sigma_max if sigma_max is not None else torch.ones_like(sigma_next)
56
+ sigma_signal = sigmax - sigma_next
57
+ sigma_residual = (sigma_next**2 - sigma_up**2).clamp(min=0) ** 0.5
58
+ alpha_ratio = sigma_signal + sigma_residual
59
+ sigma_down = sigma_residual / alpha_ratio
60
+ else:
61
+ alpha_ratio = torch.ones_like(sigma_next)
62
+ sigma_down = sigma_next
63
+ sigma_up = torch.zeros_like(sigma_next)
64
+
65
+ sigma_up = torch.nan_to_num(sigma_up if sigma_up is not None else torch.zeros_like(sigma_next), 0.0)
66
+ # Replace NaNs in sigma_down with corresponding sigma_next elements (float32)
67
+ nan_mask = torch.isnan(sigma_down)
68
+ sigma_down[nan_mask] = sigma_next[nan_mask].to(sigma_down.dtype)
69
+ alpha_ratio = torch.nan_to_num(alpha_ratio, 1.0)
70
+
71
+ return alpha_ratio, sigma_down, sigma_up
72
+
73
+ def step(
74
+ self,
75
+ sample: torch.Tensor,
76
+ denoised_sample: torch.Tensor,
77
+ sigmas: torch.Tensor,
78
+ step_index: int,
79
+ noise: torch.Tensor,
80
+ eta: float = 0.5,
81
+ ) -> torch.Tensor:
82
+ """Advance one step with SDE noise injection via get_sde_coeff.
83
+ Args:
84
+ sample: Current noisy sample.
85
+ denoised_sample: Denoised prediction from the model.
86
+ sigmas: Noise schedule tensor.
87
+ step_index: Current step index in the schedule.
88
+ noise: Random noise tensor for stochastic injection.
89
+ eta: Controls stochastic noise injection strength (0=deterministic, 1=maximum). Default 0.5.
90
+ Returns:
91
+ Next sample with SDE noise injection applied.
92
+ """
93
+ sigma = sigmas[step_index]
94
+ sigma_next = sigmas[step_index + 1]
95
+ alpha_ratio, sigma_down, sigma_up = self.get_sde_coeff(sigma_next, sigma_up=sigma_next * eta)
96
+ output_dtype = denoised_sample.dtype
97
+ if torch.any(sigma_up == 0) or torch.any(sigma_next == 0):
98
+ return denoised_sample
99
+
100
+ # Extract epsilon prediction
101
+ eps_next = (sample - denoised_sample) / (sigma - sigma_next)
102
+ denoised_next = sample - sigma * eps_next
103
+
104
+ # Mix deterministic and stochastic components
105
+ x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise
106
+ return x_noised.to(output_dtype)
ltx2/ltx_core/components/guiders.py ADDED
@@ -0,0 +1,383 @@
1
+ import math
2
+ from collections.abc import Mapping, Sequence
3
+ from dataclasses import dataclass, field
4
+
5
+ import torch
6
+
7
+ from ltx_core.components.protocols import GuiderProtocol
8
+
9
+
10
+ @dataclass(frozen=True)
11
+ class CFGGuider(GuiderProtocol):
12
+ """
13
+ Classifier-free guidance (CFG) guider.
14
+ Computes the guidance delta as (scale - 1) * (cond - uncond), steering the
15
+ denoising process toward the conditioned prediction.
16
+ Attributes:
17
+ scale: Guidance strength. 1.0 means no guidance, higher values increase
18
+ adherence to the conditioning.
19
+ """
20
+
21
+ scale: float
22
+
23
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
24
+ return (self.scale - 1) * (cond - uncond)
25
+
26
+ def enabled(self) -> bool:
27
+ return self.scale != 1.0
28
+
29
+
30
+ @dataclass(frozen=True)
31
+ class CFGStarRescalingGuider(GuiderProtocol):
32
+ """
33
+ Calculates the CFG delta between conditioned and unconditioned samples.
34
+ To minimize offset in the denoising direction and move mostly along the
35
+ conditioning axis within the distribution, the unconditioned sample is
36
+ rescaled in accordance with the norm of the conditioned sample.
37
+ Attributes:
38
+ scale (float):
39
+ Global guidance strength. A value of 1.0 corresponds to no extra
40
+ guidance beyond the base model prediction. Values > 1.0 increase
41
+ the influence of the conditioned sample relative to the
42
+ unconditioned one.
43
+ """
44
+
45
+ scale: float
46
+
47
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
48
+ rescaled_neg = projection_coef(cond, uncond) * uncond
49
+ return (self.scale - 1) * (cond - rescaled_neg)
50
+
51
+ def enabled(self) -> bool:
52
+ return self.scale != 1.0
53
+
54
+
55
+ @dataclass(frozen=True)
56
+ class STGGuider(GuiderProtocol):
57
+ """
58
+ Calculates the STG delta between conditioned and perturbed denoised samples.
59
+ Perturbed samples are the result of the denoising process with perturbations,
60
+ e.g. attentions acting as passthrough for certain layers and modalities.
61
+ Attributes:
62
+ scale (float):
63
+ Global strength of the STG guidance. A value of 0.0 disables the
64
+ guidance. Larger values increase the correction applied in the
65
+ direction of (pos_denoised - perturbed_denoised).
66
+ """
67
+
68
+ scale: float
69
+
70
+ def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
71
+ return self.scale * (pos_denoised - perturbed_denoised)
72
+
73
+ def enabled(self) -> bool:
74
+ return self.scale != 0.0
75
+
76
+
77
+ @dataclass(frozen=True)
78
+ class LtxAPGGuider(GuiderProtocol):
79
+ """
80
+ Calculates the APG (adaptive projected guidance) delta between conditioned
81
+ and unconditioned samples.
82
+ To minimize offset in the denoising direction and move mostly along the
83
+ conditioning axis within the distribution, the (cond - uncond) delta is
84
+ decomposed into components parallel and orthogonal to the conditioned
85
+ sample. The `eta` parameter weights the parallel component, while `scale`
86
+ is applied to the orthogonal component. Optionally, a norm threshold can
87
+ be used to suppress guidance when the magnitude of the correction is small.
88
+ Attributes:
89
+ scale (float):
90
+ Strength applied to the component of the guidance that is orthogonal
91
+ to the conditioned sample. Controls how aggressively we move in
92
+ directions that change semantics but stay consistent with the
93
+ conditioning manifold.
94
+ eta (float):
95
+ Weight of the component of the guidance that is parallel to the
96
+ conditioned sample. A value of 1.0 keeps the full parallel
97
+ component; values in [0, 1] attenuate it, and values > 1.0 amplify
98
+ motion along the conditioning direction.
99
+ norm_threshold (float):
100
+ Maximum L2 norm of the guidance delta; deltas whose norm exceeds
101
+ this value are rescaled down to it.
102
+ This is useful for avoiding noisy or unstable updates when the
103
+ guidance signal is very large.
104
+ """
105
+
106
+ scale: float
107
+ eta: float = 1.0
108
+ norm_threshold: float = 0.0
109
+
110
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
111
+ guidance = cond - uncond
112
+ if self.norm_threshold > 0:
113
+ ones = torch.ones_like(guidance)
114
+ guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
115
+ scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
116
+ guidance = guidance * scale_factor
117
+ proj_coeff = projection_coef(guidance, cond)
118
+ g_parallel = proj_coeff * cond
119
+ g_orth = guidance - g_parallel
120
+ g_apg = g_parallel * self.eta + g_orth
121
+
122
+ return g_apg * (self.scale - 1)
123
+
124
+ def enabled(self) -> bool:
125
+ return self.scale != 1.0
126
+
127
+
128
+ @dataclass(frozen=False)
129
+ class LegacyStatefulAPGGuider(GuiderProtocol):
130
+ """
131
+ Calculates the APG (adaptive projected guidance) delta between conditioned
132
+ and unconditioned samples.
133
+ To minimize offset in the denoising direction and move mostly along the
134
+ conditioning axis within the distribution, the (cond - uncond) delta is
135
+ decomposed into components parallel and orthogonal to the conditioned
136
+ sample. The `eta` parameter weights the parallel component, while `scale`
137
+ is applied to the orthogonal component. Optionally, a norm threshold can
138
+ be used to cap the guidance delta when its magnitude grows large.
139
+ Attributes:
140
+ scale (float):
141
+ Strength applied to the component of the guidance that is orthogonal
142
+ to the conditioned sample. Controls how aggressively we move in
143
+ directions that change semantics but stay consistent with the
144
+ conditioning manifold.
145
+ eta (float):
146
+ Weight of the component of the guidance that is parallel to the
147
+ conditioned sample. A value of 1.0 keeps the full parallel
148
+ component; values in [0, 1] attenuate it, and values > 1.0 amplify
149
+ motion along the conditioning direction.
150
+ norm_threshold (float):
151
+ Maximum L2 norm of the guidance delta; deltas whose norm exceeds
152
+ this value are rescaled down to it.
153
+ This is useful for avoiding noisy or unstable updates when the
154
+ guidance signal is very large.
155
+ momentum (float):
156
+ Exponential moving-average coefficient for accumulating guidance
157
+ over time. running_avg = momentum * running_avg + guidance
158
+ """
159
+
160
+ scale: float
161
+ eta: float
162
+ norm_threshold: float = 5.0
163
+ momentum: float = 0.0
164
+ # It is the user's responsibility not to reuse the same APGGuider across several denoisings or modalities,
165
+ # so that the accumulated running average is not shared between them.
166
+ running_avg: torch.Tensor | None = None
167
+
168
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
169
+ guidance = cond - uncond
170
+ if self.momentum != 0:
171
+ if self.running_avg is None:
172
+ self.running_avg = guidance.clone()
173
+ else:
174
+ self.running_avg = self.momentum * self.running_avg + guidance
175
+ guidance = self.running_avg
176
+
177
+ if self.norm_threshold > 0:
178
+ ones = torch.ones_like(guidance)
179
+ guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
180
+ scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
181
+ guidance = guidance * scale_factor
182
+
183
+ proj_coeff = projection_coef(guidance, cond)
184
+ g_parallel = proj_coeff * cond
185
+ g_orth = guidance - g_parallel
186
+ g_apg = g_parallel * self.eta + g_orth
187
+
188
+ return g_apg * self.scale
189
+
190
+ def enabled(self) -> bool:
191
+ return self.scale != 0.0
192
+
193
+
194
+ @dataclass(frozen=True)
195
+ class MultiModalGuiderParams:
196
+ """
197
+ Parameters for the multi-modal guider.
198
+ """
199
+
200
+ cfg_scale: float = 1.0
201
+ "CFG (Classifier-free guidance) scale controlling how strongly the model adheres to the prompt."
202
+ stg_scale: float = 0.0
203
+ "STG (Spatio-Temporal Guidance) scale controlling how strongly the model reacts to attention perturbation in the selected blocks."
204
+ stg_blocks: list[int] | None = field(default_factory=list)
205
+ "Which transformer blocks to perturb for STG."
206
+ rescale_scale: float = 0.0
207
+ "Rescale scale controlling how strongly the model rescales the modality after applying other guidance."
208
+ modality_scale: float = 1.0
209
+ "Modality CFG scale steering the prediction away from the isolated-modality (cross-modal unconditioned) prediction."
210
+ cfg_clamp_scale: float = 0.0
211
+ "Clamp guided prediction std to this multiple of conditioned prediction std. 0 = disabled."
212
+ skip_step: int = 0
213
+ "Run the full guided prediction only every (skip_step + 1) steps; 0 = never skip."
214
+
215
+
216
+ def _params_for_sigma_from_sorted_dict(
217
+ sigma: float, params_by_sigma: Sequence[tuple[float, MultiModalGuiderParams]]
218
+ ) -> MultiModalGuiderParams:
219
+ """
220
+ Return params for the given sigma from a sorted (sigma_upper_bound -> params) structure.
221
+ Keys are sorted descending (bin upper bounds). Bin i is (key_{i+1}, key_i].
222
+ Get all keys >= sigma; use last in list (smallest such key = upper bound of bin containing sigma),
223
+ or the first entry (largest key) if that list is empty (sigma above the max key).
224
+ """
225
+ if not params_by_sigma:
226
+ raise ValueError("params_by_sigma must be non-empty")
227
+ sigma = float(sigma)
228
+ keys_desc = [k for k, _ in params_by_sigma]
229
+ keys_ge_sigma = [k for k in keys_desc if k >= sigma]
230
+ # sigma above all keys: use first bin (max key)
231
+ key = keys_ge_sigma[-1] if keys_ge_sigma else keys_desc[0]
232
+ return next(p for k, p in params_by_sigma if k == key)
233
+
234
+
235
+ @dataclass(frozen=True)
236
+ class MultiModalGuider:
237
+ """
238
+ Multi-modal guider with constant params per instance.
239
+ For sigma-dependent params, use MultiModalGuiderFactory.build_from_sigma(sigma) to
240
+ obtain a guider for each step.
241
+ """
242
+
243
+ params: MultiModalGuiderParams
244
+ negative_context: torch.Tensor | None = None
245
+
246
+ def calculate(
247
+ self,
248
+ cond: torch.Tensor,
249
+ uncond_text: torch.Tensor | float,
250
+ uncond_perturbed: torch.Tensor | float,
251
+ uncond_modality: torch.Tensor | float,
252
+ ) -> torch.Tensor:
253
+ """
254
+ The guider calculates the guidance delta as (scale - 1) * (cond - uncond) for cfg and modality cfg,
255
+ and as scale * (cond - uncond) for stg, steering the denoising process away from the unconditioned
256
+ prediction.
257
+ """
258
+ pred = (
259
+ cond
260
+ + (self.params.cfg_scale - 1) * (cond - uncond_text)
261
+ + self.params.stg_scale * (cond - uncond_perturbed)
262
+ + (self.params.modality_scale - 1) * (cond - uncond_modality)
263
+ )
264
+
265
+ if self.params.rescale_scale != 0:
266
+ factor = cond.std() / pred.std()
267
+ factor = self.params.rescale_scale * factor + (1 - self.params.rescale_scale)
268
+ pred = pred * factor
269
+
270
+ # Clamp guided prediction to prevent trajectory overshoot.
271
+ # Instead of global std (which averages over all tokens), clamp per-token.
272
+ # This catches individual tokens that overshoot even if the global std looks fine.
273
+ if self.params.cfg_clamp_scale > 0:
274
+ cfg_delta = pred - cond
275
+ # Per-token magnitude clamping
276
+ delta_norm = cfg_delta.norm(dim=-1, keepdim=True) # [B, T, 1]
277
+ cond_norm = cond.norm(dim=-1, keepdim=True)
278
+ max_norm = cond_norm * self.params.cfg_clamp_scale
279
+ # Clamp tokens where delta exceeds max
280
+ scale = torch.where(
281
+ delta_norm > max_norm,
282
+ max_norm / delta_norm.clamp(min=1e-8),
283
+ torch.ones_like(delta_norm),
284
+ )
285
+ pred = cond + cfg_delta * scale
286
+
287
+ return pred
288
+
289
+ def do_unconditional_generation(self) -> bool:
290
+ """Returns True if the guider is doing unconditional generation."""
291
+ return not math.isclose(self.params.cfg_scale, 1.0)
292
+
293
+ def do_perturbed_generation(self) -> bool:
294
+ """Returns True if the guider is doing perturbed generation."""
295
+ return not math.isclose(self.params.stg_scale, 0.0)
296
+
297
+ def do_isolated_modality_generation(self) -> bool:
298
+ """Returns True if the guider is doing isolated modality generation."""
299
+ return not math.isclose(self.params.modality_scale, 1.0)
300
+
301
+ def should_skip_step(self, step: int) -> bool:
302
+ """Returns True if the guider should skip the step."""
303
+ if self.params.skip_step == 0:
304
+ return False
305
+ return step % (self.params.skip_step + 1) != 0
306
+
307
+
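The guidance combination and skip cadence above can be sketched with plain scalars. This is a minimal pure-Python stand-in (the `combine` helper is hypothetical, not part of the library) that mirrors the arithmetic in `calculate` and `should_skip_step`:

```python
# Scalar sketch of MultiModalGuider.calculate's guidance combination,
# using hypothetical example values.
def combine(cond, uncond_text, uncond_perturbed, uncond_modality,
            cfg_scale, stg_scale, modality_scale):
    return (
        cond
        + (cfg_scale - 1) * (cond - uncond_text)
        + stg_scale * (cond - uncond_perturbed)
        + (modality_scale - 1) * (cond - uncond_modality)
    )

# With neutral scales (cfg=1, stg=0, modality=1) the prediction is untouched.
assert combine(0.5, 0.1, 0.2, 0.3, 1.0, 0.0, 1.0) == 0.5

# should_skip_step(step) is step % (skip_step + 1) != 0:
# with skip_step=1, every other step is skipped.
skip_step = 1
pattern = [step % (skip_step + 1) != 0 for step in range(4)]
assert pattern == [False, True, False, True]
```

With all scales neutral, every correction term vanishes and the conditioned prediction passes through unchanged, which is why the `do_*` predicates compare the scales against their neutral values.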
308
+ @dataclass(frozen=True)
309
+ class MultiModalGuiderFactory:
310
+ """
311
+ Factory that creates a MultiModalGuider for a given sigma.
312
+ Single source of truth: _params_by_sigma (schedule). Use constant() for
313
+ one params for all sigma, from_dict() for sigma-binned params.
314
+ """
315
+
316
+ negative_context: torch.Tensor | None = None
317
+ _params_by_sigma: tuple[tuple[float, MultiModalGuiderParams], ...] = ()
318
+
319
+ @classmethod
320
+ def constant(
321
+ cls,
322
+ params: MultiModalGuiderParams,
323
+ negative_context: torch.Tensor | None = None,
324
+ ) -> "MultiModalGuiderFactory":
325
+ """Build a factory with constant params (same guider for all sigma)."""
326
+ return cls(
327
+ negative_context=negative_context,
328
+ _params_by_sigma=((float("inf"), params),),
329
+ )
330
+
331
+ @classmethod
332
+ def from_dict(
333
+ cls,
334
+ sigma_to_params: Mapping[float, MultiModalGuiderParams],
335
+ negative_context: torch.Tensor | None = None,
336
+ ) -> "MultiModalGuiderFactory":
337
+ """
338
+ Build a factory from a dict of sigma_value -> MultiModalGuiderParams.
339
+ Keys are sorted descending and used for bin lookup in params(sigma).
340
+ """
341
+ if not sigma_to_params:
342
+ raise ValueError("sigma_to_params must be non-empty")
343
+ sorted_items = tuple(sorted(sigma_to_params.items(), key=lambda x: x[0], reverse=True))
344
+ return cls(negative_context=negative_context, _params_by_sigma=sorted_items)
345
+
346
+ def params(self, sigma: float | torch.Tensor) -> MultiModalGuiderParams:
347
+ """Return params effective for the given sigma (getter; single source of truth)."""
348
+ sigma_val = float(sigma.item() if isinstance(sigma, torch.Tensor) else sigma)
349
+ return _params_for_sigma_from_sorted_dict(sigma_val, self._params_by_sigma)
350
+
351
+ def build_from_sigma(self, sigma: float | torch.Tensor) -> MultiModalGuider:
352
+ """Return a MultiModalGuider with params effective for the given sigma."""
353
+ return MultiModalGuider(
354
+ params=self.params(sigma),
355
+ negative_context=self.negative_context,
356
+ )
357
+
358
+
359
+ def create_multimodal_guider_factory(
360
+ params: MultiModalGuiderParams | MultiModalGuiderFactory,
361
+ negative_context: torch.Tensor | None = None,
362
+ ) -> MultiModalGuiderFactory:
363
+ """
364
+ Create or return a MultiModalGuiderFactory. Pass constant params for a
365
+ single-params factory (uses MultiModalGuiderFactory.constant), or an existing
366
+ MultiModalGuiderFactory. When given a factory, returns it as-is unless
367
+ negative_context is provided. For sigma-dependent params use
368
+ MultiModalGuiderFactory.from_dict(...) and pass that as params.
369
+ """
370
+ if isinstance(params, MultiModalGuiderFactory):
371
+ if negative_context is not None and params.negative_context is not negative_context:
372
+ return MultiModalGuiderFactory.from_dict(dict(params._params_by_sigma), negative_context=negative_context)
373
+ return params
374
+ return MultiModalGuiderFactory.constant(params, negative_context=negative_context)
375
+
376
+
377
+ def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
378
+ batch_size = to_project.shape[0]
379
+ positive_flat = to_project.reshape(batch_size, -1)
380
+ negative_flat = project_onto.reshape(batch_size, -1)
381
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
382
+ squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
383
+ return dot_product / squared_norm
ltx2/ltx_core/components/noisers.py ADDED
@@ -0,0 +1,35 @@
1
+ from dataclasses import replace
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.types import LatentState
7
+
8
+
9
+ class Noiser(Protocol):
10
+ """Protocol for adding noise to a latent state during diffusion."""
11
+
12
+ def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...
13
+
14
+
15
+ class GaussianNoiser(Noiser):
16
+ """Adds Gaussian noise to a latent state, scaled by the denoise mask."""
17
+
18
+ def __init__(self, generator: torch.Generator):
19
+ super().__init__()
20
+
21
+ self.generator = generator
22
+
23
+ def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
24
+ noise = torch.randn(
25
+ *latent_state.latent.shape,
26
+ device=latent_state.latent.device,
27
+ dtype=latent_state.latent.dtype,
28
+ generator=self.generator,
29
+ )
30
+ scaled_mask = latent_state.denoise_mask * noise_scale
31
+ latent = noise * scaled_mask + latent_state.latent * (1 - scaled_mask)
32
+ return replace(
33
+ latent_state,
34
+ latent=latent.to(latent_state.latent.dtype),
35
+ )
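`GaussianNoiser` interpolates between the existing latent and fresh noise under the scaled denoise mask. A scalar sketch of that interpolation (the `apply_noise` helper is hypothetical):

```python
# latent' = noise * (mask * scale) + latent * (1 - mask * scale)
def apply_noise(latent, noise, mask, noise_scale=1.0):
    m = mask * noise_scale
    return noise * m + latent * (1 - m)

assert apply_noise(latent=2.0, noise=-1.0, mask=0.0) == 2.0   # mask 0: unchanged
assert apply_noise(latent=2.0, noise=-1.0, mask=1.0) == -1.0  # mask 1: pure noise
assert apply_noise(latent=2.0, noise=-1.0, mask=1.0, noise_scale=0.5) == 0.5
```

Tokens with a zero denoise mask (conditioning tokens) are left untouched, which is how conditioned regions survive renoising.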
ltx2/ltx_core/components/patchifiers.py ADDED
@@ -0,0 +1,348 @@
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import einops
5
+ import torch
6
+
7
+ from ltx_core.components.protocols import Patchifier
8
+ from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
9
+
10
+
11
+ class VideoLatentPatchifier(Patchifier):
12
+ def __init__(self, patch_size: int):
13
+ # Patch sizes for video latents.
14
+ self._patch_size = (
15
+ 1, # temporal dimension
16
+ patch_size, # height dimension
17
+ patch_size, # width dimension
18
+ )
19
+
20
+ @property
21
+ def patch_size(self) -> Tuple[int, int, int]:
22
+ return self._patch_size
23
+
24
+ def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
25
+ return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)
26
+
27
+ def patchify(
28
+ self,
29
+ latents: torch.Tensor,
30
+ ) -> torch.Tensor:
31
+ latents = einops.rearrange(
32
+ latents,
33
+ "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
34
+ p1=self._patch_size[0],
35
+ p2=self._patch_size[1],
36
+ p3=self._patch_size[2],
37
+ )
38
+
39
+ return latents
40
+
41
+ def unpatchify(
42
+ self,
43
+ latents: torch.Tensor,
44
+ output_shape: VideoLatentShape,
45
+ ) -> torch.Tensor:
46
+ assert self._patch_size[0] == 1, "Temporal patch size must be 1 for VideoLatentPatchifier"
47
+
48
+ patch_grid_frames = output_shape.frames // self._patch_size[0]
49
+ patch_grid_height = output_shape.height // self._patch_size[1]
50
+ patch_grid_width = output_shape.width // self._patch_size[2]
51
+
52
+ latents = einops.rearrange(
53
+ latents,
54
+ "b (f h w) (c p q) -> b c f (h p) (w q)",
55
+ f=patch_grid_frames,
56
+ h=patch_grid_height,
57
+ w=patch_grid_width,
58
+ p=self._patch_size[1],
59
+ q=self._patch_size[2],
60
+ )
61
+
62
+ return latents
63
+
64
+ def get_patch_grid_bounds(
65
+ self,
66
+ output_shape: AudioLatentShape | VideoLatentShape,
67
+ device: Optional[torch.device] = None,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Return the per-dimension bounds [inclusive start, exclusive end) for every
71
+ patch produced by `patchify`. The bounds are expressed in the original
72
+ video grid coordinates: frame/time, height, and width.
73
+ The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
74
+ - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
75
+ - axis 3 (size 2) stores `[start, end)` indices within each dimension
76
+ Args:
77
+ output_shape: Video grid description containing frames, height, and width.
78
+ device: Device of the latent tensor.
79
+ """
80
+ if not isinstance(output_shape, VideoLatentShape):
81
+ raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")
82
+
83
+ frames = output_shape.frames
84
+ height = output_shape.height
85
+ width = output_shape.width
86
+ batch_size = output_shape.batch
87
+
88
+ # Validate inputs to ensure positive dimensions
89
+ assert frames > 0, f"frames must be positive, got {frames}"
90
+ assert height > 0, f"height must be positive, got {height}"
91
+ assert width > 0, f"width must be positive, got {width}"
92
+ assert batch_size > 0, f"batch_size must be positive, got {batch_size}"
93
+
94
+ # Generate grid coordinates for each dimension (frame, height, width)
95
+ # We use torch.arange to create the starting coordinates for each patch.
96
+ # indexing='ij' ensures the dimensions are in the order (frame, height, width).
97
+ grid_coords = torch.meshgrid(
98
+ torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
99
+ torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
100
+ torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
101
+ indexing="ij",
102
+ )
103
+
104
+ # Stack the grid coordinates to create the start coordinates tensor.
105
+ # Shape becomes (3, grid_f, grid_h, grid_w)
106
+ patch_starts = torch.stack(grid_coords, dim=0)
107
+
108
+ # Create a tensor containing the size of a single patch:
109
+ # (frame_patch_size, height_patch_size, width_patch_size).
110
+ # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
111
+ patch_size_delta = torch.tensor(
112
+ self._patch_size,
113
+ device=patch_starts.device,
114
+ dtype=patch_starts.dtype,
115
+ ).view(3, 1, 1, 1)
116
+
117
+ # Calculate end coordinates: start + patch_size
118
+ # Shape becomes (3, grid_f, grid_h, grid_w)
119
+ patch_ends = patch_starts + patch_size_delta
120
+
121
+ # Stack start and end coordinates together along the last dimension
122
+ # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
123
+ latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)
124
+
125
+ # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
126
+ # Final Shape: (batch_size, 3, num_patches, 2)
127
+ latent_coords = einops.repeat(
128
+ latent_coords,
129
+ "c f h w bounds -> b c (f h w) bounds",
130
+ b=batch_size,
131
+ bounds=2,
132
+ )
133
+
134
+ return latent_coords
135
+
136
+
137
+ def get_pixel_coords(
138
+ latent_coords: torch.Tensor,
139
+ scale_factors: SpatioTemporalScaleFactors,
140
+ causal_fix: bool = False,
141
+ ) -> torch.Tensor:
142
+ """
143
+ Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
144
+ each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
145
+ Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
146
+ Args:
147
+ latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
148
+ scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
149
+ per axis.
150
+ causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
151
+ that treat frame zero differently still yield non-negative timestamps.
152
+ """
153
+ # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
154
+ broadcast_shape = [1] * latent_coords.ndim
155
+ broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
156
+ scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
157
+
158
+ # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
159
+ pixel_coords = latent_coords * scale_tensor
160
+
161
+ if causal_fix:
162
+ # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
163
+ # Shift and clamp to keep the first-frame timestamps causal and non-negative.
164
+ pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
165
+
166
+ return pixel_coords
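The causal fix in `get_pixel_coords` shifts every temporal coordinate by `1 - scale` and clamps at zero, so the first latent frame (which a causal VAE encodes with unit temporal stride) maps to timestamp 0. A scalar sketch for the temporal axis (`temporal_pixel_coord` is a hypothetical helper):

```python
# latent index -> pixel-space temporal coordinate, with optional causal fix.
def temporal_pixel_coord(latent_idx, temporal_scale, causal_fix=False):
    coord = latent_idx * temporal_scale
    if causal_fix:
        coord = max(0, coord + 1 - temporal_scale)
    return coord

scale = 8
assert temporal_pixel_coord(0, scale, causal_fix=True) == 0   # clamped first frame
assert temporal_pixel_coord(1, scale, causal_fix=True) == 1   # 8 + 1 - 8
assert temporal_pixel_coord(2, scale, causal_fix=True) == 9   # 16 + 1 - 8
assert temporal_pixel_coord(2, scale, causal_fix=False) == 16
```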
167
+
168
+
169
+ class AudioPatchifier(Patchifier):
170
+ def __init__(
171
+ self,
172
+ patch_size: int,
173
+ sample_rate: int = 16000,
174
+ hop_length: int = 160,
175
+ audio_latent_downsample_factor: int = 4,
176
+ is_causal: bool = True,
177
+ shift: int = 0,
178
+ ):
179
+ """
180
+ Patchifier tailored for spectrogram/audio latents.
181
+ Args:
182
+ patch_size: Number of mel bins combined into a single patch. This
183
+ controls the resolution along the frequency axis.
184
+ sample_rate: Original waveform sampling rate. Used to map latent
185
+ indices back to seconds so downstream consumers can align audio
186
+ and video cues.
187
+ hop_length: Window hop length used for the spectrogram. Determines
188
+ how many real-time samples separate two consecutive latent frames.
189
+ audio_latent_downsample_factor: Ratio between spectrogram frames and
190
+ latent frames; compensates for additional downsampling inside the
191
+ VAE encoder.
192
+ is_causal: When True, timing is shifted to account for causal
193
+ receptive fields so timestamps do not peek into the future.
194
+ shift: Integer offset applied to the latent indices. Enables
195
+ constructing overlapping windows from the same latent sequence.
196
+ """
197
+ self.hop_length = hop_length
198
+ self.sample_rate = sample_rate
199
+ self.audio_latent_downsample_factor = audio_latent_downsample_factor
200
+ self.is_causal = is_causal
201
+ self.shift = shift
202
+ self._patch_size = (1, patch_size, patch_size)
203
+
204
+ @property
205
+ def patch_size(self) -> Tuple[int, int, int]:
206
+ return self._patch_size
207
+
208
+ def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
209
+ return tgt_shape.frames
210
+
211
+ def _get_audio_latent_time_in_sec(
212
+ self,
213
+ start_latent: int,
214
+ end_latent: int,
215
+ dtype: torch.dtype,
216
+ device: Optional[torch.device] = None,
217
+ ) -> torch.Tensor:
218
+ """
219
+ Converts latent indices into real-time seconds while honoring causal
220
+ offsets and the configured hop length.
221
+ Args:
222
+ start_latent: Inclusive start index inside the latent sequence. This
223
+ sets the first timestamp returned.
224
+ end_latent: Exclusive end index. Determines how many timestamps get
225
+ generated.
226
+ dtype: Floating-point dtype used for the returned tensor, allowing
227
+ callers to control precision.
228
+ device: Target device for the timestamp tensor. When omitted the
229
+ computation occurs on CPU to avoid surprising GPU allocations.
230
+ """
231
+ if device is None:
232
+ device = torch.device("cpu")
233
+
234
+ audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
235
+
236
+ audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
237
+
238
+ if self.is_causal:
239
+ # Frame offset for causal alignment.
240
+ # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
241
+ causal_offset = 1
242
+ audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)
243
+
244
+ return audio_mel_frame * self.hop_length / self.sample_rate
245
+
246
+ def _compute_audio_timings(
247
+ self,
248
+ batch_size: int,
249
+ num_steps: int,
250
+ device: Optional[torch.device] = None,
251
+ ) -> torch.Tensor:
252
+ """
253
+ Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
254
+ This helper method underpins `get_patch_grid_bounds` for the audio patchifier.
255
+ Args:
256
+ batch_size: Number of sequences to broadcast the timings over.
257
+ num_steps: Number of latent frames (time steps) to convert into timestamps.
258
+ device: Device on which the resulting tensor should reside.
259
+ """
260
+ resolved_device = device
261
+ if resolved_device is None:
262
+ resolved_device = torch.device("cpu")
263
+
264
+ start_timings = self._get_audio_latent_time_in_sec(
265
+ self.shift,
266
+ num_steps + self.shift,
267
+ torch.float32,
268
+ resolved_device,
269
+ )
270
+ start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
271
+
272
+ end_timings = self._get_audio_latent_time_in_sec(
273
+ self.shift + 1,
274
+ num_steps + self.shift + 1,
275
+ torch.float32,
276
+ resolved_device,
277
+ )
278
+ end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
279
+
280
+ return torch.stack([start_timings, end_timings], dim=-1)
281
+
282
+ def patchify(
283
+ self,
284
+ audio_latents: torch.Tensor,
285
+ ) -> torch.Tensor:
286
+ """
287
+ Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
288
+ to derive timestamps for each latent frame based on the configured hop
289
+ length and downsampling.
290
+ Args:
291
+ audio_latents: Latent tensor to patchify.
292
+ Returns:
293
+ Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
294
+ corresponding timing metadata when needed.
295
+ """
296
+ audio_latents = einops.rearrange(
297
+ audio_latents,
298
+ "b c t f -> b t (c f)",
299
+ )
300
+
301
+ return audio_latents
302
+
303
+ def unpatchify(
304
+ self,
305
+ audio_latents: torch.Tensor,
306
+ output_shape: AudioLatentShape,
307
+ ) -> torch.Tensor:
308
+ """
309
+ Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
310
+ Use `get_patch_grid_bounds` to recompute the timestamps that describe each
311
+ frame's position in real time.
312
+ Args:
313
+ audio_latents: Latent tensor to unpatchify.
314
+ output_shape: Shape of the unpatched output tensor.
315
+ Returns:
316
+ Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
317
+ metadata associated with the restored latents.
318
+ """
319
+ # audio_latents shape: (batch, time, freq * channels)
320
+ audio_latents = einops.rearrange(
321
+ audio_latents,
322
+ "b t (c f) -> b c t f",
323
+ c=output_shape.channels,
324
+ f=output_shape.mel_bins,
325
+ )
326
+
327
+ return audio_latents
328
+
329
+ def get_patch_grid_bounds(
330
+ self,
331
+ output_shape: AudioLatentShape | VideoLatentShape,
332
+ device: Optional[torch.device] = None,
333
+ ) -> torch.Tensor:
334
+ """
335
+ Return the temporal bounds `[inclusive start, exclusive end)` for every
336
+ patch emitted by `patchify`. For audio this corresponds to timestamps in
337
+ seconds aligned with the original spectrogram grid.
338
+ The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
339
+ - axis 1 (size 1) represents the temporal dimension
340
+ - axis 3 (size 2) stores the `[start, end)` timestamps per patch
341
+ Args:
342
+ output_shape: Audio grid specification describing the number of time steps.
343
+ device: Target device for the returned tensor.
344
+ """
345
+ if not isinstance(output_shape, AudioLatentShape):
346
+ raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")
347
+
348
+ return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
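The latent-index-to-seconds mapping in `_get_audio_latent_time_in_sec` can be checked with a pure-Python sketch using the constructor's default-looking parameters (treated here as illustrative values):

```python
# latent index -> mel frame (with causal offset) -> seconds.
def latent_time_sec(latent_idx, downsample=4, hop_length=160,
                    sample_rate=16000, is_causal=True):
    mel_frame = latent_idx * downsample
    if is_causal:
        # Shift so the timestamp marks the first fully available sample.
        mel_frame = max(0, mel_frame + 1 - downsample)
    return mel_frame * hop_length / sample_rate

assert latent_time_sec(0) == 0.0               # clamped by the causal offset
assert latent_time_sec(1) == 1 * 160 / 16000   # 0.01 s
assert latent_time_sec(2) == 5 * 160 / 16000   # 0.05 s
```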
ltx2/ltx_core/components/protocols.py ADDED
@@ -0,0 +1,101 @@
1
+ from typing import Protocol, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.types import AudioLatentShape, VideoLatentShape
6
+
7
+
8
+ class Patchifier(Protocol):
9
+ """
10
+ Protocol for patchifiers that convert latent tensors into patches and assemble them back.
11
+ """
12
+
13
+ def patchify(
14
+ self,
15
+ latents: torch.Tensor,
16
+ ) -> torch.Tensor:
18
+ """
19
+ Convert latent tensors into flattened patch tokens.
20
+ Args:
21
+ latents: Latent tensor to patchify.
22
+ Returns:
23
+ Flattened patch tokens tensor.
24
+ """
25
+
26
+ def unpatchify(
27
+ self,
28
+ latents: torch.Tensor,
29
+ output_shape: AudioLatentShape | VideoLatentShape,
30
+ ) -> torch.Tensor:
31
+ """
32
+ Converts latent tensors between spatio-temporal formats and flattened sequence representations.
33
+ Args:
34
+ latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
35
+ output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
36
+ VideoLatentShape.
37
+ Returns:
38
+ Dense latent tensor restored from the flattened representation.
39
+ """
40
+
41
+ @property
42
+ def patch_size(self) -> Tuple[int, int, int]:
44
+ """
45
+ Returns the patch size as a tuple of (temporal, height, width) dimensions
46
+ """
47
+
48
+ def get_patch_grid_bounds(
49
+ self,
50
+ output_shape: AudioLatentShape | VideoLatentShape,
51
+ device: torch.device | None = None,
52
+ ) -> torch.Tensor:
54
+ """
55
+ Compute metadata describing where each latent patch resides within the
56
+ grid specified by `output_shape`.
57
+ Args:
58
+ output_shape: Target grid layout for the patches.
59
+ device: Target device for the returned tensor.
60
+ Returns:
61
+ Tensor containing patch coordinate metadata such as spatial or temporal intervals.
62
+ """
63
+
64
+
65
+ class SchedulerProtocol(Protocol):
66
+ """
67
+ Protocol for schedulers that provide a sigmas schedule tensor for a
68
+ given number of steps. Device is cpu.
69
+ """
70
+
71
+ def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ...
72
+
73
+
74
+ class GuiderProtocol(Protocol):
75
+ """
76
+ Protocol for guiders that compute a delta tensor given conditioning inputs.
77
+ The returned delta should be added to the conditional output (cond), enabling
78
+ multiple guiders to be chained together by accumulating their deltas.
79
+ """
80
+
81
+ scale: float
82
+
83
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ...
84
+
85
+ def enabled(self) -> bool:
86
+ """
87
+ Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
88
+ is 1.0.
89
+ """
90
+ ...
91
+
92
+
93
+ class DiffusionStepProtocol(Protocol):
94
+ """
95
+ Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
96
+ current denoised sample tensor, and sigmas tensor.
97
+ """
98
+
99
+ def step(
100
+ self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **kwargs
101
+ ) -> torch.Tensor: ...
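`GuiderProtocol`'s docstring says deltas accumulate onto the conditional output, so guiders can be chained. A sketch with a hypothetical scalar implementer of the protocol (`CfgDelta` is not a class from the library):

```python
import math

# Hypothetical guider: delta = (scale - 1) * (cond - uncond), disabled at scale 1.
class CfgDelta:
    def __init__(self, scale):
        self.scale = scale

    def delta(self, cond, uncond):
        return (self.scale - 1) * (cond - uncond)

    def enabled(self):
        return not math.isclose(self.scale, 1.0)

guiders = [CfgDelta(3.0), CfgDelta(1.0)]  # the second guider is disabled
cond, uncond = 1.0, 0.5
pred = cond + sum(g.delta(cond, uncond) for g in guiders if g.enabled())
assert pred == 1.0 + 2 * 0.5  # only the first guider contributes
```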
ltx2/ltx_core/components/schedulers.py ADDED
@@ -0,0 +1,130 @@
1
+ import math
2
+ from functools import lru_cache
3
+
4
+ import numpy
5
+ import scipy
6
+ import torch
7
+
8
+ from ltx_core.components.protocols import SchedulerProtocol
9
+
10
+ BASE_SHIFT_ANCHOR = 1024
11
+ MAX_SHIFT_ANCHOR = 4096
12
+
13
+
14
+ class LTX2Scheduler(SchedulerProtocol):
15
+ """
16
+ Default scheduler for LTX-2 diffusion sampling.
17
+ Generates a sigma schedule with token-count-dependent shifting and optional
18
+ stretching to a terminal value.
19
+ """
20
+
21
+ def execute(
22
+ self,
23
+ steps: int,
24
+ latent: torch.Tensor | None = None,
25
+ max_shift: float = 2.05,
26
+ base_shift: float = 0.95,
27
+ stretch: bool = True,
28
+ terminal: float = 0.1,
29
+ default_number_of_tokens: int = MAX_SHIFT_ANCHOR,
30
+ **_kwargs,
31
+ ) -> torch.FloatTensor:
32
+ tokens = math.prod(latent.shape[2:]) if latent is not None else default_number_of_tokens
33
+ sigmas = torch.linspace(1.0, 0.0, steps + 1)
34
+
35
+ x1 = BASE_SHIFT_ANCHOR
36
+ x2 = MAX_SHIFT_ANCHOR
37
+ mm = (max_shift - base_shift) / (x2 - x1)
38
+ b = base_shift - mm * x1
39
+ sigma_shift = tokens * mm + b
40
+
41
+ power = 1
42
+ sigmas = torch.where(
43
+ sigmas != 0,
44
+ math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
45
+ 0,
46
+ )
47
+
48
+ # Stretch sigmas so that its final value matches the given terminal value.
49
+ if stretch:
50
+ non_zero_mask = sigmas != 0
51
+ non_zero_sigmas = sigmas[non_zero_mask]
52
+ one_minus_z = 1.0 - non_zero_sigmas
53
+ scale_factor = one_minus_z[-1] / (1.0 - terminal)
54
+ stretched = 1.0 - (one_minus_z / scale_factor)
55
+ sigmas[non_zero_mask] = stretched
56
+
57
+ return sigmas.to(torch.float32)
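The token-dependent shift above interpolates linearly between `base_shift` at `BASE_SHIFT_ANCHOR` tokens and `max_shift` at `MAX_SHIFT_ANCHOR` tokens, then warps each sigma through a sigmoid-like curve. A pure-Python check of that arithmetic (helper names are illustrative):

```python
import math

BASE_SHIFT_ANCHOR, MAX_SHIFT_ANCHOR = 1024, 4096

def sigma_shift(tokens, base_shift=0.95, max_shift=2.05):
    # Linear interpolation of the shift against token count.
    mm = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
    b = base_shift - mm * BASE_SHIFT_ANCHOR
    return tokens * mm + b

def shift_sigma(sigma, shift):
    # The per-sigma warp applied to the linear schedule (power = 1).
    return math.exp(shift) / (math.exp(shift) + (1 / sigma - 1))

assert abs(sigma_shift(1024) - 0.95) < 1e-9
assert abs(sigma_shift(4096) - 2.05) < 1e-9
# A positive shift pushes intermediate sigmas upward (more time at high noise).
assert shift_sigma(0.5, sigma_shift(4096)) > 0.5
```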
58
+
59
+
60
+ class LinearQuadraticScheduler(SchedulerProtocol):
61
+ """
62
+ Scheduler with linear steps followed by quadratic steps.
63
+ Produces a sigma schedule that transitions linearly up to a threshold,
64
+ then follows a quadratic curve for the remaining steps.
65
+ """
66
+
67
+ def execute(
68
+ self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
69
+ ) -> torch.FloatTensor:
70
+ if steps == 1:
71
+ return torch.FloatTensor([1.0, 0.0])
72
+
73
+ if linear_steps is None:
74
+ linear_steps = steps // 2
75
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
76
+ threshold_noise_step_diff = linear_steps - threshold_noise * steps
77
+ quadratic_steps = steps - linear_steps
78
+ quadratic_sigma_schedule = []
79
+ if quadratic_steps > 0:
80
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
81
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
82
+ const = quadratic_coef * (linear_steps**2)
83
+ quadratic_sigma_schedule = [
84
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, steps)
85
+ ]
86
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
87
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
88
+ return torch.FloatTensor(sigma_schedule)
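The linear and quadratic branches above are constructed so the curve is continuous at `linear_steps` and reaches 1.0 at the final step (before the `1 - x` flip). A pure-Python replica of the noise curve that verifies both properties (`noise_curve` is a stand-in name):

```python
# Replica of LinearQuadraticScheduler's noise curve before the final 1 - x flip.
def noise_curve(steps, threshold_noise=0.025, linear_steps=None):
    if linear_steps is None:
        linear_steps = steps // 2
    diff = linear_steps - threshold_noise * steps
    quadratic_steps = steps - linear_steps
    linear_part = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    quadratic_coef = diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * diff / quadratic_steps**2
    const = quadratic_coef * linear_steps**2
    quadratic_part = [
        quadratic_coef * i**2 + linear_coef * i + const
        for i in range(linear_steps, steps)
    ]
    return linear_part + quadratic_part + [1.0]

curve = noise_curve(20)
assert len(curve) == 21
# The quadratic branch meets the linear branch at threshold_noise...
assert abs(curve[10] - 0.025) < 1e-12
# ...and the curve rises monotonically from 0 to 1.
assert curve[0] == 0.0 and curve[-1] == 1.0
assert all(a < b for a, b in zip(curve, curve[1:]))
```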
89
+
90
+
91
+ class BetaScheduler(SchedulerProtocol):
92
+ """
93
+ Scheduler using a beta distribution to sample timesteps.
94
+ Based on: https://arxiv.org/abs/2407.12173
95
+ """
96
+
97
+ shift = 2.37
98
+ timesteps_length = 10000
99
+
100
+ def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor:
101
+ """
102
+ Execute the beta scheduler.
103
+ Args:
104
+ steps: The number of steps to execute the scheduler for.
105
+ alpha: The alpha parameter for the beta distribution.
106
+ beta: The beta parameter for the beta distribution.
107
+ Warnings:
108
+ The returned `sigmas` may contain fewer than `steps + 1` entries,
109
+ because identical timesteps are deduplicated.
110
+ Returns:
111
+ A tensor of sigmas.
112
+ """
113
+ model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
114
+ total_timesteps = len(model_sampling_sigmas) - 1
115
+ ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
116
+ ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist()
117
+ ts = list(dict.fromkeys(ts))
118
+
119
+ sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0]
120
+ return torch.FloatTensor(sigmas)
121
+
122
+
123
+ @lru_cache(maxsize=5)
124
+ def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
125
+ timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
126
+ return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps])
127
+
128
+
129
+ def flux_time_shift(mu: float, sigma: float, t: float) -> float:
130
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
ltx2/ltx_core/conditioning/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """Conditioning utilities: latent state, tools, and conditioning types."""
2
+
3
+ from ltx_core.conditioning.exceptions import ConditioningError
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.conditioning.types import (
6
+ ConditioningItemAttentionStrengthWrapper,
7
+ VideoConditionByKeyframeIndex,
8
+ VideoConditionByLatentIndex,
9
+ VideoConditionByReferenceLatent,
10
+ )
11
+
12
+ __all__ = [
13
+ "ConditioningError",
14
+ "ConditioningItem",
15
+ "ConditioningItemAttentionStrengthWrapper",
16
+ "VideoConditionByKeyframeIndex",
17
+ "VideoConditionByLatentIndex",
18
+ "VideoConditionByReferenceLatent",
19
+ ]
ltx2/ltx_core/conditioning/exceptions.py ADDED
@@ -0,0 +1,4 @@
1
+ class ConditioningError(Exception):
2
+ """
3
+ Class for conditioning-related errors.
4
+ """
ltx2/ltx_core/conditioning/item.py ADDED
@@ -0,0 +1,20 @@
1
+ from typing import Protocol
2
+
3
+ from ltx_core.tools import LatentTools
4
+ from ltx_core.types import LatentState
5
+
6
+
7
+ class ConditioningItem(Protocol):
8
+ """Protocol for conditioning items that modify latent state during diffusion."""
9
+
10
+ def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
11
+ """
12
+ Apply the conditioning to the latent state.
13
+ Args:
14
+ latent_state: The latent state to apply the conditioning to. This state is always patchified.
15
+ Returns:
16
+ The latent state after the conditioning has been applied.
17
+ IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should add them to the end of the
18
+ latent.
19
+ """
20
+ ...
ltx2/ltx_core/conditioning/mask_utils.py ADDED
@@ -0,0 +1,210 @@
1
+ """Utilities for building 2D self-attention masks for conditioning items."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import torch
8
+
9
+ if TYPE_CHECKING:
10
+ from ltx_core.types import LatentState
11
+
12
+
13
+ def resolve_cross_mask(
14
+ attention_mask: float | int | torch.Tensor,
15
+ num_new_tokens: int,
16
+ batch_size: int,
17
+ device: torch.device,
18
+ dtype: torch.dtype,
19
+ ) -> torch.Tensor:
20
+ """Convert an attention_mask (scalar or tensor) to a (B, M) cross_mask tensor.
21
+ Args:
22
+ attention_mask: Scalar value applied uniformly, 1D tensor of shape (M,)
23
+ broadcast across batch, or 2D tensor of shape (B, M).
24
+ num_new_tokens: Number of new conditioning tokens M.
25
+ batch_size: Batch size B.
26
+ device: Device for the output tensor.
27
+ dtype: Data type for the output tensor.
28
+ Returns:
29
+ Cross-mask tensor of shape (B, M).
30
+ """
31
+ if isinstance(attention_mask, (int, float)):
32
+ return torch.full(
33
+ (batch_size, num_new_tokens),
34
+ fill_value=float(attention_mask),
35
+ device=device,
36
+ dtype=dtype,
37
+ )
38
+ mask = attention_mask.to(device=device, dtype=dtype)
39
+
40
+ # Handle scalar (0-D) tensor like a Python scalar.
41
+ if mask.dim() == 0:
42
+ return torch.full(
43
+ (batch_size, num_new_tokens),
44
+ fill_value=float(mask.item()),
45
+ device=device,
46
+ dtype=dtype,
47
+ )
48
+
49
+ if mask.dim() == 1:
50
+ if mask.shape[0] != num_new_tokens:
51
+ raise ValueError(
52
+ f"1-D attention_mask length must equal num_new_tokens ({num_new_tokens}), got shape {tuple(mask.shape)}"
53
+ )
54
+ mask = mask.unsqueeze(0).expand(batch_size, -1)
55
+ elif mask.dim() == 2:
56
+ b, m = mask.shape
57
+ if m != num_new_tokens:
58
+ raise ValueError(
59
+ f"2-D attention_mask second dimension must equal num_new_tokens ({num_new_tokens}), "
60
+ f"got shape {tuple(mask.shape)}"
61
+ )
62
+ if b not in (batch_size, 1):
63
+ raise ValueError(
64
+ f"2-D attention_mask batch dimension must equal batch_size ({batch_size}) or 1, "
65
+ f"got shape {tuple(mask.shape)}"
66
+ )
67
+ if b == 1 and batch_size > 1:
68
+ mask = mask.expand(batch_size, -1)
69
+ else:
70
+ raise ValueError(
71
+ f"attention_mask tensor must be 0-D, 1-D, or 2-D, got {mask.dim()}-D with shape {tuple(mask.shape)}"
72
+ )
73
+ return mask
74
+
75
+
76
+ def update_attention_mask(
77
+ latent_state: LatentState,
78
+ attention_mask: float | torch.Tensor | None,
79
+ num_noisy_tokens: int,
80
+ num_new_tokens: int,
81
+ batch_size: int,
82
+ device: torch.device,
83
+ dtype: torch.dtype,
84
+ ) -> torch.Tensor | None:
85
+ """Build or update the self-attention mask for newly appended conditioning tokens.
86
+ If *attention_mask* is ``None`` and no existing mask is present, returns
87
+ ``None``. If *attention_mask* is ``None`` but an existing mask is present,
88
+ the mask is expanded with full attention (1s) for the new tokens so that
89
+ its dimensions stay consistent with the growing latent sequence. Otherwise,
90
+ resolves *attention_mask* to a per-token cross-mask and expands the 2-D
91
+ attention mask via :func:`build_attention_mask`.
92
+ Args:
93
+ latent_state: Current latent state (provides the existing mask and total
94
+ existing-token count).
95
+ attention_mask: Per-token attention weight. Scalar, 1-D ``(M,)``, 2-D
96
+ ``(B, M)`` tensor, or ``None`` (no-op).
97
+ num_noisy_tokens: Number of original noisy tokens (from
98
+ ``latent_tools.target_shape.token_count()``).
99
+ num_new_tokens: Number of new conditioning tokens being appended.
100
+ batch_size: Batch size.
101
+ device: Device for the output tensor.
102
+ dtype: Data type for the output tensor.
103
+ Returns:
104
+ Updated attention mask of shape ``(B, N+M, N+M)``, or ``None`` if no
105
+ masking is needed.
106
+ """
107
+ if attention_mask is None:
108
+ if latent_state.attention_mask is None:
109
+ return None
110
+ # Existing mask present but no new mask requested: pad with 1s (full
111
+ # attention) so the mask dimensions stay consistent with the growing
112
+ # latent sequence.
113
+ cross_mask = torch.ones(batch_size, num_new_tokens, device=device, dtype=dtype)
114
+ return build_attention_mask(
115
+ existing_mask=latent_state.attention_mask,
116
+ num_noisy_tokens=num_noisy_tokens,
117
+ num_new_tokens=num_new_tokens,
118
+ num_existing_tokens=latent_state.latent.shape[1],
119
+ cross_mask=cross_mask,
120
+ device=device,
121
+ dtype=dtype,
122
+ )
123
+
124
+ cross_mask = resolve_cross_mask(attention_mask, num_new_tokens, batch_size, device, dtype)
125
+ return build_attention_mask(
126
+ existing_mask=latent_state.attention_mask,
127
+ num_noisy_tokens=num_noisy_tokens,
128
+ num_new_tokens=num_new_tokens,
129
+ num_existing_tokens=latent_state.latent.shape[1],
130
+ cross_mask=cross_mask,
131
+ device=device,
132
+ dtype=dtype,
133
+ )
134
+
135
+
136
+ def build_attention_mask(
137
+ existing_mask: torch.Tensor | None,
138
+ num_noisy_tokens: int,
139
+ num_new_tokens: int,
140
+ num_existing_tokens: int,
141
+ cross_mask: torch.Tensor,
142
+ device: torch.device,
143
+ dtype: torch.dtype,
144
+ ) -> torch.Tensor:
145
+ """
146
+ Expand the attention mask to include newly appended conditioning tokens.
147
+ Each conditioning item appends M new reference tokens to the sequence. This function
148
+ builds a (B, N+M, N+M) attention mask with the following block structure:
149
+ noisy prev_ref new_ref
150
+ (N_noisy) (N-N_noisy) (M)
151
+ ┌───────────┬───────────┬───────────┐
152
+ noisy │ │ │ │
153
+ (N_noisy) │ existing │ existing │ cross │
154
+ │ │ │ │
155
+ ├───────────┼───────────┼───────────┤
156
+ prev_ref │ │ │ │
157
+ (N-N_noisy)│ existing │ existing │ 0 │
158
+ │ │ │ │
159
+ ├───────────┼───────────┼───────────┤
160
+ new_ref │ │ │ │
161
+ (M) │ cross │ 0 │ 1 │
162
+ │ │ │ │
163
+ └───────────┴───────────┴───────────┘
164
+ Where:
165
+ - **existing**: preserved from the previous mask (or 1.0 if first conditioning)
166
+ - **cross**: values from *cross_mask* (shape B, M), in [0, 1]
167
+ - **0**: no attention between different reference groups
168
+ Args:
169
+ existing_mask: Current attention mask of shape (B, N, N), or None if no mask exists yet.
170
+ When None, the top-left NxN block is filled with 1s (full attention between all
171
+ existing tokens including any prior reference tokens that had no mask).
172
+ num_noisy_tokens: Number of original noisy tokens (always at positions [0:num_noisy_tokens]).
173
+ num_new_tokens: Number of new conditioning tokens M being appended.
174
+ num_existing_tokens: Total number of current tokens N (noisy + any prior conditioning tokens).
175
+ cross_mask: Per-token attention weight of shape (B, M) controlling attention between
176
+ new reference tokens and noisy tokens. Values in [0, 1].
177
+ device: Device for the output tensor.
178
+ dtype: Data type for the output tensor.
179
+ Returns:
180
+ Attention mask of shape (B, N+M, N+M) with values in [0, 1].
181
+ """
182
+ batch_size = cross_mask.shape[0]
183
+ total = num_existing_tokens + num_new_tokens
184
+
185
+ # Start with zeros
186
+ mask = torch.zeros((batch_size, total, total), device=device, dtype=dtype)
187
+
188
+ # Top-left: preserve existing mask or fill with 1s for noisy tokens
189
+ if existing_mask is not None:
190
+ mask[:, :num_existing_tokens, :num_existing_tokens] = existing_mask
191
+ else:
192
+ mask[:, :num_existing_tokens, :num_existing_tokens] = 1.0
193
+
194
+ # Bottom-right: new reference tokens fully attend to themselves
195
+ mask[:, num_existing_tokens:, num_existing_tokens:] = 1.0
196
+
197
+ # Cross-attention between noisy tokens and new reference tokens
198
+ # cross_mask shape: (B, M) -> broadcast to (B, N_noisy, M) and (B, M, N_noisy)
199
+
200
+ # Noisy tokens attending to new reference tokens: [0:N_noisy, N:N+M]
201
+ # Each column j in this block gets cross_mask[:, j]
202
+ mask[:, :num_noisy_tokens, num_existing_tokens:] = cross_mask.unsqueeze(1)
203
+
204
+ # New reference tokens attending to noisy tokens: [N:N+M, 0:N_noisy]
205
+ # Each row i in this block gets cross_mask[:, i]
206
+ mask[:, num_existing_tokens:, :num_noisy_tokens] = cross_mask.unsqueeze(2)
207
+
208
+ # [N_noisy:N, N:N+M] and [N:N+M, N_noisy:N] remain 0 (no cross-ref attention)
209
+
210
+ return mask
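The block structure drawn in `build_attention_mask` can be checked with a pure-Python toy version for a single batch element (nested lists instead of tensors, no prior `existing_mask`); this is a sketch of the semantics, not the real implementation:

```python
def build_block_mask(num_noisy, num_existing, num_new, cross):
    """Toy single-sample version of build_attention_mask.

    num_existing = noisy tokens + previously appended reference tokens;
    cross is a per-new-token weight list of length num_new.
    """
    total = num_existing + num_new
    mask = [[0.0] * total for _ in range(total)]
    # Top-left block: all existing tokens attend to each other (no prior mask).
    for i in range(num_existing):
        for j in range(num_existing):
            mask[i][j] = 1.0
    # Bottom-right block: new reference tokens fully attend to themselves.
    for i in range(num_existing, total):
        for j in range(num_existing, total):
            mask[i][j] = 1.0
    # Cross blocks: only NOISY tokens exchange attention with the new
    # reference tokens; prev_ref <-> new_ref stays 0.
    for i in range(num_noisy):
        for j in range(num_new):
            mask[i][num_existing + j] = cross[j]
            mask[num_existing + j][i] = cross[j]
    return mask


# 2 noisy tokens, 1 previously appended ref token (num_existing=3), 2 new refs
m = build_block_mask(2, 3, 2, [0.5, 0.5])
```

Row/column 2 is the previous reference token: its entries against the new reference tokens (columns 3-4) stay 0, matching the "no attention between different reference groups" rule in the diagram.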
ltx2/ltx_core/conditioning/types/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """Conditioning type implementations."""
+
+ from ltx_core.conditioning.types.attention_strength_wrapper import ConditioningItemAttentionStrengthWrapper
+ from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex
+ from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex
+ from ltx_core.conditioning.types.reference_video_cond import VideoConditionByReferenceLatent
+
+ __all__ = [
+     "ConditioningItemAttentionStrengthWrapper",
+     "VideoConditionByKeyframeIndex",
+     "VideoConditionByLatentIndex",
+     "VideoConditionByReferenceLatent",
+ ]
ltx2/ltx_core/conditioning/types/attention_strength_wrapper.py ADDED
@@ -0,0 +1,71 @@
+ """Wrapper conditioning item that adds attention masking to any inner conditioning."""
+
+ from dataclasses import replace
+
+ import torch
+
+ from ltx_core.conditioning.item import ConditioningItem
+ from ltx_core.conditioning.mask_utils import update_attention_mask
+ from ltx_core.tools import LatentTools
+ from ltx_core.types import LatentState
+
+
+ class ConditioningItemAttentionStrengthWrapper(ConditioningItem):
+     """Wraps a conditioning item to add an attention mask for its tokens.
+     Separates the *attention-masking* concern from the underlying conditioning
+     logic (token layout, positional encoding, denoise strength). The inner
+     conditioning item appends tokens to the latent sequence as usual, and this
+     wrapper then builds or updates the self-attention mask so that the newly
+     added tokens interact with the noisy tokens according to *attention_mask*.
+     Args:
+         conditioning: Any conditioning item that appends tokens to the latent.
+         attention_mask: Per-token attention weight controlling how strongly the
+             new conditioning tokens attend to/from noisy tokens. Can be a
+             scalar (float) applied uniformly, or a tensor of shape ``(B, M)``
+             for spatial control, where ``M = F * H * W`` is the number of
+             patchified conditioning tokens. Values in ``[0, 1]``.
+     Example::
+         cond = ConditioningItemAttentionStrengthWrapper(
+             VideoConditionByReferenceLatent(latent=ref, strength=1.0),
+             attention_mask=0.5,
+         )
+         state = cond.apply_to(latent_state, latent_tools)
+     """
+
+     def __init__(
+         self,
+         conditioning: ConditioningItem,
+         attention_mask: float | torch.Tensor,
+     ):
+         self.conditioning = conditioning
+         self.attention_mask = attention_mask
+
+     def apply_to(
+         self,
+         latent_state: LatentState,
+         latent_tools: LatentTools,
+     ) -> LatentState:
+         """Apply inner conditioning, then build the attention mask for its tokens."""
+         # Snapshot the original state for mask building
+         original_state = latent_state
+
+         # Inner conditioning appends tokens (positions, denoise mask, etc.)
+         new_state = self.conditioning.apply_to(latent_state, latent_tools)
+
+         num_new_tokens = new_state.latent.shape[1] - original_state.latent.shape[1]
+         if num_new_tokens == 0:
+             return new_state
+
+         # Build the attention mask using the *original* state as the reference
+         # so that the block structure is computed correctly.
+         new_attention_mask = update_attention_mask(
+             latent_state=original_state,
+             attention_mask=self.attention_mask,
+             num_noisy_tokens=latent_tools.target_shape.token_count(),
+             num_new_tokens=num_new_tokens,
+             batch_size=new_state.latent.shape[0],
+             device=new_state.latent.device,
+             dtype=new_state.latent.dtype,
+         )
+
+         return replace(new_state, attention_mask=new_attention_mask)
ltx2/ltx_core/conditioning/types/keyframe_cond.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+
+ from ltx_core.components.patchifiers import get_pixel_coords
+ from ltx_core.conditioning.item import ConditioningItem
+ from ltx_core.conditioning.mask_utils import update_attention_mask
+ from ltx_core.tools import VideoLatentTools
+ from ltx_core.types import LatentState, VideoLatentShape
+
+
+ class VideoConditionByKeyframeIndex(ConditioningItem):
+     """
+     Conditions video generation on keyframe latents at a specific frame index.
+     Appends keyframe tokens to the latent state with positions offset by frame_idx,
+     and sets denoise strength according to the strength parameter.
+     To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.
+     Args:
+         keyframes: Keyframe latents [B, C, F, H, W].
+         frame_idx: Frame index offset for positional encoding.
+         strength: Conditioning strength (1.0 = clean, 0.0 = fully denoised).
+     """
+
+     def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
+         self.keyframes = keyframes
+         self.frame_idx = frame_idx
+         self.strength = strength
+
+     def apply_to(
+         self,
+         latent_state: LatentState,
+         latent_tools: VideoLatentTools,
+     ) -> LatentState:
+         tokens = latent_tools.patchifier.patchify(self.keyframes)
+         latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
+             output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
+             device=self.keyframes.device,
+         )
+         positions = get_pixel_coords(
+             latent_coords=latent_coords,
+             scale_factors=latent_tools.scale_factors,
+             causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
+         )
+
+         positions[:, 0, ...] += self.frame_idx
+         positions = positions.to(dtype=torch.float32)
+         positions[:, 0, ...] /= latent_tools.fps
+
+         denoise_mask = torch.full(
+             size=(*tokens.shape[:2], 1),
+             fill_value=1.0 - self.strength,
+             device=self.keyframes.device,
+             dtype=self.keyframes.dtype,
+         )
+
+         new_attention_mask = update_attention_mask(
+             latent_state=latent_state,
+             attention_mask=None,
+             num_noisy_tokens=latent_tools.target_shape.token_count(),
+             num_new_tokens=tokens.shape[1],
+             batch_size=tokens.shape[0],
+             device=self.keyframes.device,
+             dtype=self.keyframes.dtype,
+         )
+
+         return LatentState(
+             latent=torch.cat([latent_state.latent, tokens], dim=1),
+             denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
+             positions=torch.cat([latent_state.positions, positions], dim=2),
+             clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
+             attention_mask=new_attention_mask,
+         )
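The temporal position arithmetic in `VideoConditionByKeyframeIndex.apply_to` (offset frame positions by `frame_idx`, then divide by `fps` to get seconds) can be sketched for one spatial location. This is a simplification: the real code works on a full position tensor produced by `get_pixel_coords` and also handles spatial axes and the causal fix.

```python
def keyframe_time_positions(num_frames, frame_idx, fps):
    """Hypothetical per-frame temporal positions, in seconds.

    Mirrors: positions[:, 0] += frame_idx; positions[:, 0] /= fps
    """
    return [(f + frame_idx) / fps for f in range(num_frames)]


# 3 keyframe frames inserted at frame index 2, at 1 fps for easy reading
times = keyframe_time_positions(3, 2, 1.0)
```

With a real frame rate (e.g. `fps=24.0`) the same offset lands the keyframe tokens at `frame_idx / 24` seconds, aligned with the target sequence's time axis.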
ltx2/ltx_core/conditioning/types/latent_cond.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+
+ from ltx_core.conditioning.exceptions import ConditioningError
+ from ltx_core.conditioning.item import ConditioningItem
+ from ltx_core.tools import LatentTools
+ from ltx_core.types import LatentState
+
+
+ class VideoConditionByLatentIndex(ConditioningItem):
+     """
+     Conditions video generation by injecting latents at a specific latent frame index.
+     Replaces tokens in the latent state at positions corresponding to latent_idx,
+     and sets denoise strength according to the strength parameter.
+     """
+
+     def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int):
+         self.latent = latent
+         self.strength = strength
+         self.latent_idx = latent_idx
+
+     def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
+         cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape
+         tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape()
+
+         if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width):
+             raise ConditioningError(
+                 f"Can't apply conditioning latent with shape {tuple(self.latent.shape)} to target shape "
+                 f"({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure "
+                 "the conditioning latent and the target latent share batch, channel, and spatial dimensions."
+             )
+
+         tokens = latent_tools.patchifier.patchify(self.latent)
+         start_token = latent_tools.patchifier.get_token_count(
+             latent_tools.target_shape._replace(frames=self.latent_idx)
+         )
+         stop_token = start_token + tokens.shape[1]
+
+         latent_state = latent_state.clone()
+
+         latent_state.latent[:, start_token:stop_token] = tokens
+         latent_state.clean_latent[:, start_token:stop_token] = tokens
+         latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength
+
+         return latent_state
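The token range overwritten by `VideoConditionByLatentIndex` comes from `get_token_count` on a shape truncated to `latent_idx` frames: the start offset is the token count of everything before that frame. Under the simplifying assumption of one token per latent voxel in frame-major order (the real patchifier may group voxels into patches), the arithmetic looks like:

```python
def token_slice_for_latent_index(latent_idx, cond_frames, height, width):
    """Hypothetical token range replaced when injecting `cond_frames` latent
    frames at latent frame index `latent_idx`.

    start = tokens in a video with `latent_idx` frames (everything before the
    injection point); stop = start + tokens in the conditioning latent.
    """
    tokens_per_frame = height * width
    start = latent_idx * tokens_per_frame
    stop = start + cond_frames * tokens_per_frame
    return start, stop


# Inject 1 latent frame at index 2 into a 4x4 latent grid
rng = token_slice_for_latent_index(2, 1, 4, 4)
```

Everything inside `[start, stop)` is overwritten in `latent`, `clean_latent`, and `denoise_mask`; tokens outside the range are untouched.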
ltx2/ltx_core/conditioning/types/noise_mask_cond.py ADDED
@@ -0,0 +1,45 @@
+ from dataclasses import dataclass
+
+ from ltx_core.components.patchifiers import get_pixel_coords
+ from ltx_core.conditioning.item import ConditioningItem
+ from ltx_core.tools import LatentTools, SpatioTemporalScaleFactors
+ from ltx_core.types import AudioLatentShape, LatentState, VideoLatentShape
+
+
+ @dataclass(frozen=True)
+ class TemporalRegionMask(ConditioningItem):
+     """Conditioning item that sets ``denoise_mask = 0`` outside a time range
+     and ``1`` inside, so only the specified temporal region is regenerated.
+     Uses ``start_time`` and ``end_time`` in seconds. Works in *patchified*
+     (token) space using the patchifier's ``get_patch_grid_bounds``: for video,
+     coords are latent frame indices (converted to seconds via ``fps``); for
+     audio, coords are already in seconds.
+     """
+
+     start_time: float  # seconds, inclusive
+     end_time: float  # seconds, exclusive
+     fps: float
+
+     def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
+         coords = latent_tools.patchifier.get_patch_grid_bounds(
+             latent_tools.target_shape, device=latent_state.denoise_mask.device
+         )
+         if isinstance(latent_tools.target_shape, AudioLatentShape):
+             # Audio: get_patch_grid_bounds returns bounds in seconds
+             t_boundaries = coords[:, 0]
+         elif isinstance(latent_tools.target_shape, VideoLatentShape):
+             # Video: get_patch_grid_bounds returns latent bounds; convert to frame numbers / pixel bounds
+             scale_factors = getattr(latent_tools, "scale_factors", SpatioTemporalScaleFactors.default())
+             pixel_bounds = get_pixel_coords(coords, scale_factors, causal_fix=getattr(latent_tools, "causal_fix", True))
+             # Convert frame numbers to seconds
+             t_boundaries = pixel_bounds[:, 0] / self.fps
+         else:
+             raise ValueError("Unsupported LatentShape type, expected AudioLatentShape or VideoLatentShape")
+         t_start, t_end = t_boundaries.unbind(dim=-1)  # [B, N]
+         in_region = (t_end > self.start_time) & (t_start < self.end_time)
+         state = latent_state.clone()
+         mask_val = in_region.to(state.denoise_mask.dtype)
+         if state.denoise_mask.dim() == 3:
+             mask_val = mask_val.unsqueeze(-1)
+         state.denoise_mask.copy_(mask_val)
+         return state
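The membership test in `TemporalRegionMask` is a standard half-open interval overlap: a token whose time span is `[t_start, t_end)` is regenerated iff `t_end > start_time and t_start < end_time`. A pure-Python sketch over per-token bounds (lists of `(t_start, t_end)` pairs in seconds, standing in for the tensor version):

```python
def temporal_denoise_mask(bounds, start_time, end_time):
    """1.0 for tokens overlapping [start_time, end_time), else 0.0.

    `bounds` is a list of (t_start, t_end) spans, one per token.
    """
    return [
        1.0 if (t_end > start_time and t_start < end_time) else 0.0
        for t_start, t_end in bounds
    ]


# Three one-second tokens; only the middle one overlaps [1.0, 2.0)
mask = temporal_denoise_mask([(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)], 1.0, 2.0)
```

Note the strict inequalities: a token that merely touches the boundary (its span ends exactly at `start_time`, or begins exactly at `end_time`) is excluded, which matches the inclusive-start / exclusive-end convention documented on the dataclass fields.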
ltx2/ltx_core/conditioning/types/reference_video_cond.py ADDED
@@ -0,0 +1,91 @@
+ """Reference video conditioning for IC-LoRA inference."""
+
+ import torch
+
+ from ltx_core.components.patchifiers import get_pixel_coords
+ from ltx_core.conditioning.item import ConditioningItem
+ from ltx_core.conditioning.mask_utils import update_attention_mask
+ from ltx_core.tools import VideoLatentTools
+ from ltx_core.types import LatentState, VideoLatentShape
+
+
+ class VideoConditionByReferenceLatent(ConditioningItem):
+     """
+     Conditions video generation on a reference video latent for IC-LoRA inference.
+     IC-LoRAs are trained by concatenating reference (control signal) and target tokens,
+     learning to attend across both. This class replicates that setup at inference by
+     appending reference tokens to the latent sequence.
+     IC-LoRAs can be trained with lower-resolution references than the target (e.g., 384px
+     reference for 768px output) for efficiency and better generalization. The
+     `downscale_factor` scales reference positions to match target coordinates, preserving
+     the learned positional relationships. This must match the factor used during training
+     (stored in LoRA metadata).
+     To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.
+     Args:
+         latent: Reference video latents [B, C, F, H, W].
+         downscale_factor: Target/reference resolution ratio (e.g., 2 = half-resolution
+             reference). Spatial positions are scaled by this factor.
+         strength: Conditioning strength. 1.0 = full (reference kept clean),
+             0.0 = none (reference denoised). Default 1.0.
+     """
+
+     def __init__(
+         self,
+         latent: torch.Tensor,
+         downscale_factor: int = 1,
+         strength: float = 1.0,
+     ):
+         self.latent = latent
+         self.downscale_factor = downscale_factor
+         self.strength = strength
+
+     def apply_to(
+         self,
+         latent_state: LatentState,
+         latent_tools: VideoLatentTools,
+     ) -> LatentState:
+         """Append reference video tokens with scaled positions."""
+         tokens = latent_tools.patchifier.patchify(self.latent)
+
+         # Compute positions for the reference video's actual dimensions
+         latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
+             output_shape=VideoLatentShape.from_torch_shape(self.latent.shape),
+             device=self.latent.device,
+         )
+         positions = get_pixel_coords(
+             latent_coords=latent_coords,
+             scale_factors=latent_tools.scale_factors,
+             causal_fix=latent_tools.causal_fix,
+         )
+         positions = positions.to(dtype=torch.float32)
+         positions[:, 0, ...] /= latent_tools.fps
+
+         # Scale spatial positions to match target coordinate space
+         if self.downscale_factor != 1:
+             positions[:, 1, ...] *= self.downscale_factor  # height axis
+             positions[:, 2, ...] *= self.downscale_factor  # width axis
+
+         denoise_mask = torch.full(
+             size=(*tokens.shape[:2], 1),
+             fill_value=1.0 - self.strength,
+             device=self.latent.device,
+             dtype=self.latent.dtype,
+         )
+
+         new_attention_mask = update_attention_mask(
+             latent_state=latent_state,
+             attention_mask=None,
+             num_noisy_tokens=latent_tools.target_shape.token_count(),
+             num_new_tokens=tokens.shape[1],
+             batch_size=tokens.shape[0],
+             device=self.latent.device,
+             dtype=self.latent.dtype,
+         )
+
+         return LatentState(
+             latent=torch.cat([latent_state.latent, tokens], dim=1),
+             denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
+             positions=torch.cat([latent_state.positions, positions], dim=2),
+             clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
+             attention_mask=new_attention_mask,
+         )
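The `downscale_factor` trick above only touches positional metadata: a half-resolution reference keeps its own (smaller) token grid, but its spatial coordinates are multiplied so they span the same range as the full-resolution target. A toy version over `(t_seconds, y, x)` tuples (the real code scales channels 1 and 2 of a position tensor, leaving the time axis alone):

```python
def scale_reference_positions(positions, downscale_factor):
    """Hypothetical sketch: scale spatial axes of per-token positions.

    `positions` is a list of (t_seconds, y, x); time is left unscaled.
    """
    return [(t, y * downscale_factor, x * downscale_factor) for t, y, x in positions]


# A 384px-grid reference token mapped into 768px target coordinates
scaled = scale_reference_positions([(0.0, 3, 5), (0.0, 7, 7)], 2)
```

With `downscale_factor=2`, reference token `(y=3, x=5)` lands at `(6, 10)` in target space, so the rotary/positional encoding sees the same relative geometry the IC-LoRA was trained with.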
ltx2/ltx_core/guidance/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """Guidance and perturbation utilities for attention manipulation."""
+
+ from ltx_core.guidance.perturbations import (
+     BatchedPerturbationConfig,
+     Perturbation,
+     PerturbationConfig,
+     PerturbationType,
+ )
+
+ __all__ = [
+     "BatchedPerturbationConfig",
+     "Perturbation",
+     "PerturbationConfig",
+     "PerturbationType",
+ ]
ltx2/ltx_core/guidance/perturbations.py ADDED
@@ -0,0 +1,79 @@
+ from dataclasses import dataclass
+ from enum import Enum
+
+ import torch
+ from torch._prims_common import DeviceLikeType
+
+
+ class PerturbationType(Enum):
+     """Types of attention perturbations for STG (Spatio-Temporal Guidance)."""
+
+     SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
+     SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
+     SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
+     SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
+
+
+ @dataclass(frozen=True)
+ class Perturbation:
+     """A single perturbation specifying which attention type to skip and in which blocks."""
+
+     type: PerturbationType
+     blocks: list[int] | None  # None means all blocks
+
+     def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
+         if self.type != perturbation_type:
+             return False
+
+         if self.blocks is None:
+             return True
+
+         return block in self.blocks
+
+
+ @dataclass(frozen=True)
+ class PerturbationConfig:
+     """Configuration holding a list of perturbations for a single sample."""
+
+     perturbations: list[Perturbation] | None
+
+     def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
+         if self.perturbations is None:
+             return False
+
+         return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
+
+     @staticmethod
+     def empty() -> "PerturbationConfig":
+         return PerturbationConfig([])
+
+
+ @dataclass(frozen=True)
+ class BatchedPerturbationConfig:
+     """Perturbation configurations for a batch, with utilities for generating attention masks."""
+
+     perturbations: list[PerturbationConfig]
+
+     def mask(
+         self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
+     ) -> torch.Tensor:
+         mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
+         for batch_idx, perturbation in enumerate(self.perturbations):
+             if perturbation.is_perturbed(perturbation_type, block):
+                 mask[batch_idx] = 0
+
+         return mask
+
+     def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
+         mask = self.mask(perturbation_type, block, values.device, values.dtype)
+         return mask.view(mask.numel(), *([1] * len(values.shape[1:])))
+
+     def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
+         return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
+
+     def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
+         return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
+
+     @staticmethod
+     def empty(batch_size: int) -> "BatchedPerturbationConfig":
+         return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
ltx2/ltx_core/layer_streaming.py ADDED
@@ -0,0 +1,324 @@
+ """Layer streaming wrapper for memory-efficient inference.
+ Keeps most transformer/decoder layers on CPU pinned memory and streams them
+ to GPU on demand, using a secondary CUDA stream to prefetch upcoming layers
+ so that data transfer overlaps with compute.
+ General-purpose: works with any ``nn.Module`` whose forward iterates over a
+ ``nn.ModuleList`` attribute (e.g. ``transformer_blocks``, ``layers``).
+ Each layer is evicted back to CPU immediately after its forward completes,
+ and prefetch uses modular indexing so the last layer's prefetch wraps around
+ to prepare early layers for the next forward pass.
+ Example
+ -------
+ >>> model = build_my_model(device=torch.device("cpu"))
+ >>> model = LayerStreamingWrapper(
+ ...     model,
+ ...     layers_attr="transformer_blocks",
+ ...     target_device=torch.device("cuda:0"),
+ ...     prefetch_count=2,
+ ... )
+ >>> out = model(inputs)  # hooks handle layer streaming
+ >>> model.teardown()  # move everything back to CPU
+ """
+
+ from __future__ import annotations
+
+ import functools
+ import itertools
+ import logging
+ from typing import Any
+
+ import torch
+ from torch import nn
+
+ logger = logging.getLogger(__name__)
+
+
+ def _resolve_attr(module: nn.Module, dotted_path: str) -> nn.ModuleList:
+     """Resolve a dotted attribute path like ``'model.language_model.layers'``."""
+     obj: Any = module
+     for part in dotted_path.split("."):
+         obj = getattr(obj, part)
+     if not isinstance(obj, nn.ModuleList):
+         raise TypeError(f"Expected nn.ModuleList at '{dotted_path}', got {type(obj).__name__}")
+     return obj
+
+
+ class _LayerStore:
+     """Manages on-demand pinning of layer parameters for GPU streaming.
+     Stores references to each layer's source data (which may be file-backed
+     mmap views or in-memory tensors). When a layer needs to be transferred
+     to GPU, its source data is pinned on demand and copied; on eviction the
+     pinned copy is freed and the source data is restored.
+     """
+
+     def __init__(self, layers: nn.ModuleList, target_device: torch.device) -> None:
+         self.target_device = target_device
+         self.num_layers = len(layers)
+         self._on_gpu: set[int] = set()
+
+         # Keep a reference to the source data for each layer so we can pin it
+         # on demand and restore it after eviction.
+         self._source_data: list[dict[str, torch.Tensor]] = []
+         for layer in layers:
+             source: dict[str, torch.Tensor] = {}
+             for name, tensor in itertools.chain(layer.named_parameters(), layer.named_buffers()):
+                 source[name] = tensor.data
+             self._source_data.append(source)
+
+         # Hold pinned tensors alive until the H2D transfer completes.
+         # Without this, the CachingHostAllocator can reclaim a pinned tensor
+         # as soon as its Python reference is dropped, even if an async H2D
+         # transfer is still reading from it.
+         self._pinned_in_flight: dict[int, list[torch.Tensor]] = {}
+
+     def _check_idx(self, idx: int) -> None:
+         if idx < 0 or idx >= self.num_layers:
+             raise IndexError(f"Layer index {idx} out of range [0, {self.num_layers})")
+
+     def is_on_gpu(self, idx: int) -> bool:
+         return idx in self._on_gpu
+
+     def move_to_gpu(self, idx: int, layer: nn.Module, *, non_blocking: bool = False) -> None:
+         """Pin layer *idx* on demand, then transfer to GPU."""
+         self._check_idx(idx)
+         if idx in self._on_gpu:
+             return
+         source = self._source_data[idx]
+         pinned_refs: list[torch.Tensor] = []
+         for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()):
+             pinned = source[name].pin_memory()
+             param.data = pinned.to(self.target_device, non_blocking=non_blocking)
+             pinned_refs.append(pinned)
+         # Keep pinned tensors alive until eviction — the async H2D transfer
+         # may still be reading from them.
+         self._pinned_in_flight[idx] = pinned_refs
+         self._on_gpu.add(idx)
+
+     def evict_to_cpu(self, idx: int, layer: nn.Module) -> None:
+         """Restore source data, freeing the GPU and pinned copies."""
+         self._check_idx(idx)
+         if idx not in self._on_gpu:
+             return
+         source = self._source_data[idx]
+         for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()):
+             param.data = source[name]
+         # Release pinned tensors — the H2D transfer is complete by now
+         # (the compute stream waited on the prefetch event before using
+         # the layer, and we only evict after compute finishes).
+         self._pinned_in_flight.pop(idx, None)
+         self._on_gpu.discard(idx)
+
+     def cleanup(self) -> None:
+         """Release all source data and in-flight pinned references.
+         After this call, the source tensors can be garbage-collected once
+         the layer parameters (which still reference them via ``.data``) are
+         also released (e.g. via ``.to("meta")``).
+         """
+         for source_dict in self._source_data:
+             source_dict.clear()
+         self._source_data.clear()
+         self._pinned_in_flight.clear()
+
+
+ class _AsyncPrefetcher:
+     """Issues H2D transfers on a dedicated CUDA stream.
+     Uses per-layer CUDA events so that the compute stream only waits for the
+     specific layer it needs, not all pending transfers.
+     """
+
+     def __init__(self, store: _LayerStore, layers: nn.ModuleList) -> None:
+         self._store = store
+         self._layers = layers
+         self._stream = torch.cuda.Stream(device=store.target_device)
+         self._events: dict[int, torch.cuda.Event] = {}
+
+     def prefetch(self, idx: int) -> None:
+         """Begin async transfer of layer *idx* to GPU (no-op if already there)."""
+         if self._store.is_on_gpu(idx) or idx in self._events:
+             return
+         with torch.cuda.stream(self._stream):
+             self._store.move_to_gpu(idx, self._layers[idx], non_blocking=True)
+         event = torch.cuda.Event()
+         event.record(self._stream)
+         self._events[idx] = event
+
+     def wait(self, idx: int) -> None:
+         """Block the compute stream until layer *idx* transfer is complete.
147
+ event = self._events.pop(idx, None)
148
+ if event is not None:
149
+ torch.cuda.current_stream(self._store.target_device).wait_event(event)
150
+
151
+ def cleanup(self) -> None:
152
+ """Drain pending work and release CUDA stream/event resources."""
153
+ self._events.clear()
154
+ self._stream = None
155
+ self._layers = None
156
+ self._store = None
157
+
158
+
159
+ class LayerStreamingWrapper(nn.Module):
160
+ """Wraps a model to stream its sequential layers between CPU and GPU.
161
+ Each layer is evicted immediately after its forward completes, and
162
+ prefetch wraps around using modular indexing so the end of one forward
163
+ pass prepares early layers for the next.
164
+ Parameters
165
+ ----------
166
+ model:
167
+ The model to wrap, with all parameters on **CPU**.
168
+ layers_attr:
169
+ Dotted attribute path to the ``nn.ModuleList`` of sequential layers
170
+ (e.g. ``"transformer_blocks"`` or ``"model.language_model.layers"``).
171
+ target_device:
172
+ The GPU device to use for compute.
173
+ prefetch_count:
174
+ How many layers ahead to prefetch. The maximum number of layers on
175
+ GPU at once is ``1 + prefetch_count``. Must be >= 1.
176
+ """
177
+
178
+ def __init__(
179
+ self,
180
+ model: nn.Module,
181
+ layers_attr: str,
182
+ target_device: torch.device,
183
+ prefetch_count: int = 2,
184
+ ) -> None:
185
+ if prefetch_count < 1:
186
+ raise ValueError("prefetch_count must be >= 1")
187
+ super().__init__()
188
+ # Store the wrapped model as a submodule so parameters are discoverable.
189
+ self._model = model
190
+ self._layers = _resolve_attr(model, layers_attr)
191
+ self._target_device = target_device
192
+ # Clamp: no point prefetching more than num_layers - 1 (the rest are evicted).
193
+ self._prefetch_count = min(prefetch_count, len(self._layers) - 1)
194
+ self._hooks: list[torch.utils.hooks.RemovableHandle] = []
195
+
196
+ self._setup()
197
+
198
+ # ------------------------------------------------------------------
199
+ # Setup / teardown
200
+ # ------------------------------------------------------------------
201
+
202
+ def _setup(self) -> None:
203
+ # 1. Build the pinned CPU store (copies all layer tensors to pinned memory).
204
+ self._store = _LayerStore(self._layers, self._target_device)
205
+
206
+ # 2. Move all NON-layer params/buffers to GPU.
207
+ layer_tensor_ids: set[int] = set()
208
+ for layer in self._layers:
209
+ for t in itertools.chain(layer.parameters(), layer.buffers()):
210
+ layer_tensor_ids.add(id(t))
211
+
212
+ for p in self._model.parameters():
213
+ if id(p) not in layer_tensor_ids:
214
+ p.data = p.data.to(self._target_device)
215
+ for b in self._model.buffers():
216
+ if id(b) not in layer_tensor_ids:
217
+ b.data = b.data.to(self._target_device)
218
+
219
+ # 3. Pre-load the first (1 + prefetch_count) layers synchronously.
220
+ for idx in range(min(self._prefetch_count + 1, len(self._layers))):
221
+ self._store.move_to_gpu(idx, self._layers[idx])
222
+
223
+ # 4. Create the async prefetcher and register hooks.
224
+ self._prefetcher = _AsyncPrefetcher(self._store, self._layers)
225
+ self._register_hooks()
226
+
227
+ def _register_hooks(self) -> None:
228
+ idx_map: dict[int, int] = {id(layer): idx for idx, layer in enumerate(self._layers)}
229
+ num_layers = len(self._layers)
230
+
231
+ compute_stream = torch.cuda.current_stream(self._target_device)
232
+
233
+ def _pre_hook(
234
+ module: nn.Module,
235
+ _args: Any, # noqa: ANN401
236
+ *,
237
+ idx: int,
238
+ ) -> None:
239
+ # Wait only for THIS layer's H2D transfer (not all pending ones).
240
+ self._prefetcher.wait(idx)
241
+ if not self._store.is_on_gpu(idx):
242
+ self._store.move_to_gpu(idx, module)
243
+
244
+ # Record that the compute stream will read these weight tensors.
245
+ # They were allocated on the prefetch stream, so without this the
246
+ # caching allocator would allow the prefetch stream to reuse their
247
+ # memory immediately after eviction — even if the compute kernel
248
+ # that reads them hasn't finished yet.
249
+ for param in itertools.chain(module.parameters(), module.buffers()):
250
+ param.data.record_stream(compute_stream)
251
+
252
+ # Kick off prefetch for upcoming layers (wraps around for next pass).
253
+ for offset in range(1, self._prefetch_count + 1):
254
+ self._prefetcher.prefetch((idx + offset) % num_layers)
255
+
256
+ def _post_hook(
257
+ module: nn.Module,
258
+ _args: Any, # noqa: ANN401
259
+ _output: Any, # noqa: ANN401
260
+ *,
261
+ idx: int,
262
+ ) -> None:
263
+ # Evict this layer immediately — its computation is done.
264
+ self._store.evict_to_cpu(idx, module)
265
+
266
+ for layer in self._layers:
267
+ idx = idx_map[id(layer)]
268
+ h1 = layer.register_forward_pre_hook(functools.partial(_pre_hook, idx=idx))
269
+ h2 = layer.register_forward_hook(functools.partial(_post_hook, idx=idx))
270
+ self._hooks.extend([h1, h2])
271
+
272
+ def teardown(self) -> None:
273
+ """Remove hooks, release resources, and move parameters back to CPU.
274
+ After this call the wrapper is inert: hooks are removed, the prefetch
275
+ stream is drained and destroyed, all parameters reside on CPU, and the
276
+ ``_LayerStore`` source data references are cleared. Callers should
277
+ still follow up with ``.to("meta")`` to release the CPU copies if the
278
+ model is no longer needed.
279
+ """
280
+ for h in self._hooks:
281
+ h.remove()
282
+ self._hooks.clear()
283
+
284
+ # Drain all in-flight async H2D copies, then release stream resources.
285
+ # Without the synchronize, clearing the stream/events can trigger
286
+ # use-after-free at the CUDA driver level.
287
+ torch.cuda.synchronize(device=self._target_device)
288
+ if self._prefetcher is not None:
289
+ self._prefetcher.cleanup()
290
+ self._prefetcher = None
291
+
292
+ # Move everything to CPU.
293
+ for idx, layer in enumerate(self._layers):
294
+ self._store.evict_to_cpu(idx, layer)
295
+
296
+ for p in self._model.parameters():
297
+ p.data = p.data.to("cpu")
298
+ for b in self._model.buffers():
299
+ b.data = b.data.to("cpu")
300
+
301
+ # Release source data references. After evict_to_cpu() the layer
302
+ # params point to the source data. The caller is expected to follow
303
+ # up with .to("meta") to drop the param refs; cleanup() drops the
304
+ # store's refs.
305
+ self._store.cleanup()
306
+
307
+ # ------------------------------------------------------------------
308
+ # Forward and attribute delegation
309
+ # ------------------------------------------------------------------
310
+
311
+ def forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401
312
+ return self._model(*args, **kwargs)
313
+
314
+ def __getattr__(self, name: str) -> Any: # noqa: ANN401
315
+ """Proxy attribute access to the wrapped model.
316
+ This allows calling methods like ``encode()`` on a wrapped
317
+ GemmaTextEncoder without the caller needing to know about the wrapper.
318
+ ``nn.Module.__getattr__`` is only called when normal attribute lookup
319
+ fails, so ``_model``, ``_store``, etc. are found first via ``__dict__``.
320
+ """
321
+ try:
322
+ return super().__getattr__(name)
323
+ except AttributeError:
324
+ return getattr(self._model, name)
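The wrap-around indexing used by `_pre_hook` is worth seeing in isolation. The sketch below is pure Python with no CUDA; `prefetch_schedule` is a hypothetical helper (not part of the codebase) that mirrors the `(idx + offset) % num_layers` loop, showing how running the last layers warms up the first layers for the next forward pass:

```python
def prefetch_schedule(num_layers: int, prefetch_count: int) -> list[tuple[int, list[int]]]:
    """For each layer about to run, list the layers its pre-hook prefetches.

    Mirrors the modular indexing in LayerStreamingWrapper._pre_hook, so the
    tail of one forward pass prepares the head of the next.
    """
    return [
        (idx, [(idx + off) % num_layers for off in range(1, prefetch_count + 1)])
        for idx in range(num_layers)
    ]

# With 4 layers and prefetch_count=2, running layer 3 prefetches layers 0 and 1
# for the next forward pass; at most 1 + prefetch_count layers are resident.
schedule = prefetch_schedule(4, 2)
```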
ltx2/ltx_core/loader/__init__.py ADDED
@@ -0,0 +1,48 @@
+"""Loader utilities for model weights, LoRAs, and safetensor operations."""
+
+from ltx_core.loader.fuse_loras import apply_loras
+from ltx_core.loader.module_ops import ModuleOps
+from ltx_core.loader.primitives import (
+    LoRAAdaptableProtocol,
+    LoraPathStrengthAndSDOps,
+    LoraStateDictWithStrength,
+    ModelBuilderProtocol,
+    StateDict,
+    StateDictLoader,
+)
+from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry
+from ltx_core.loader.sd_ops import (
+    LTXV_LORA_COMFY_RENAMING_MAP,
+    ContentMatching,
+    ContentReplacement,
+    KeyValueOperation,
+    KeyValueOperationResult,
+    SDKeyValueOperation,
+    SDOps,
+)
+from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader
+from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
+
+__all__ = [
+    "LTXV_LORA_COMFY_RENAMING_MAP",
+    "ContentMatching",
+    "ContentReplacement",
+    "DummyRegistry",
+    "KeyValueOperation",
+    "KeyValueOperationResult",
+    "LoRAAdaptableProtocol",
+    "LoraPathStrengthAndSDOps",
+    "LoraStateDictWithStrength",
+    "ModelBuilderProtocol",
+    "ModuleOps",
+    "Registry",
+    "SDKeyValueOperation",
+    "SDOps",
+    "SafetensorsModelStateDictLoader",
+    "SafetensorsStateDictLoader",
+    "SingleGPUModelBuilder",
+    "StateDict",
+    "StateDictLoader",
+    "StateDictRegistry",
+    "apply_loras",
+]
ltx2/ltx_core/loader/fuse_loras.py ADDED
@@ -0,0 +1,133 @@
+from collections.abc import Iterator
+
+import torch
+
+from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
+from ltx_core.quantization.fp8_cast import _fused_add_round_launch
+from ltx_core.quantization.fp8_scaled_mm import quantize_weight_to_fp8_per_tensor
+
+
+def _get_device() -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device("cuda", torch.cuda.current_device())
+    return torch.device("cpu")
+
+
+def fuse_lora_weights(
+    model_sd: StateDict,
+    lora_sd_and_strengths: list[LoraStateDictWithStrength],
+    dtype: torch.dtype | None = None,
+) -> Iterator[tuple[str, torch.Tensor]]:
+    """Yield ``(key, fused_tensor)`` for each weight modified by at least one LoRA.
+
+    For scaled-FP8 weights, this includes both the updated ``.weight`` tensor
+    and its corresponding ``.weight_scale`` tensor.
+    """
+    for key, original_weight in model_sd.sd.items():
+        if original_weight is None or key.endswith(".weight_scale"):
+            continue
+        original_device = original_weight.device
+        weight = original_weight.to(device=_get_device())
+        target_dtype = dtype if dtype is not None else weight.dtype
+        deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
+
+        deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, weight.device)
+        if deltas is None:
+            continue
+
+        scale_key = key.replace(".weight", ".weight_scale") if key.endswith(".weight") else None
+        is_scaled_fp8 = scale_key is not None and scale_key in model_sd.sd
+
+        if weight.dtype == torch.float8_e4m3fn:
+            if is_scaled_fp8:
+                fused = _fuse_delta_with_scaled_fp8(deltas, weight, key, scale_key, model_sd)
+            else:
+                fused = _fuse_delta_with_cast_fp8(deltas, weight, key, target_dtype)
+        elif weight.dtype == torch.bfloat16:
+            fused = _fuse_delta_with_bfloat16(deltas, weight, key, target_dtype)
+        else:
+            raise ValueError(f"Unsupported dtype: {weight.dtype}")
+
+        for k, v in fused.items():
+            yield k, v.to(device=original_device)
+
+
+def apply_loras(
+    model_sd: StateDict,
+    lora_sd_and_strengths: list[LoraStateDictWithStrength],
+    dtype: torch.dtype | None = None,
+    destination_sd: StateDict | None = None,
+) -> StateDict:
+    if destination_sd is not None:
+        sd = destination_sd.sd
+        for key, tensor in fuse_lora_weights(model_sd, lora_sd_and_strengths, dtype):
+            sd[key] = tensor
+        return destination_sd
+
+    fused = dict(fuse_lora_weights(model_sd, lora_sd_and_strengths, dtype))
+    sd = {k: (fused[k] if k in fused else v.clone()) for k, v in model_sd.sd.items()}
+    return StateDict(sd, model_sd.device, model_sd.size, model_sd.dtype)
+
+
+def _prepare_deltas(
+    lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
+) -> torch.Tensor | None:
+    deltas = []
+    prefix = key[: -len(".weight")]
+    key_a = f"{prefix}.lora_A.weight"
+    key_b = f"{prefix}.lora_B.weight"
+    for lsd, coef in lora_sd_and_strengths:
+        if key_a not in lsd.sd or key_b not in lsd.sd:
+            continue
+        a = lsd.sd[key_a].to(device=device)
+        b = lsd.sd[key_b].to(device=device)
+        product = torch.matmul(b * coef, a)
+        del a, b
+        deltas.append(product.to(dtype=dtype))
+    if len(deltas) == 0:
+        return None
+    elif len(deltas) == 1:
+        return deltas[0]
+    return torch.sum(torch.stack(deltas, dim=0), dim=0)
+
+
+def _fuse_delta_with_scaled_fp8(
+    deltas: torch.Tensor,
+    weight: torch.Tensor,
+    key: str,
+    scale_key: str,
+    model_sd: StateDict,
+) -> dict[str, torch.Tensor]:
+    """Dequantize scaled FP8 weight, add LoRA delta, and re-quantize."""
+    weight_scale = model_sd.sd[scale_key]
+
+    original_weight = weight.t().to(torch.float32) * weight_scale
+
+    new_weight = original_weight + deltas.to(torch.float32)
+
+    new_fp8_weight, new_weight_scale = quantize_weight_to_fp8_per_tensor(new_weight)
+    return {key: new_fp8_weight, scale_key: new_weight_scale}
+
+
+def _fuse_delta_with_cast_fp8(
+    deltas: torch.Tensor,
+    weight: torch.Tensor,
+    key: str,
+    target_dtype: torch.dtype,
+) -> dict[str, torch.Tensor]:
+    """Fuse LoRA delta with cast-only FP8 weight (no scale factor)."""
+    if str(weight.device).startswith("cuda"):
+        _fused_add_round_launch(deltas, weight, seed=0)
+    else:
+        deltas.add_(weight.to(dtype=deltas.dtype))
+    return {key: deltas.to(dtype=target_dtype)}
+
+
+def _fuse_delta_with_bfloat16(
+    deltas: torch.Tensor,
+    weight: torch.Tensor,
+    key: str,
+    target_dtype: torch.dtype,
+) -> dict[str, torch.Tensor]:
+    """Fuse LoRA delta with bfloat16 weight."""
+    deltas.add_(weight)
+    return {key: deltas.to(dtype=target_dtype)}
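`_prepare_deltas` computes `strength * (B @ A)` for each LoRA and sums the results; the `_fuse_delta_with_*` helpers then add that delta onto the base weight. A minimal pure-Python sketch of that arithmetic (the `fuse_lora` helper below is hypothetical and uses lists of lists instead of torch tensors):

```python
def fuse_lora(weight, loras):
    """Compute W' = W + sum_i s_i * (B_i @ A_i).

    weight is an (out x in) matrix; each LoRA is a tuple (A, B, strength)
    with A of shape (r x in) and B of shape (out x r).
    """
    out_dim, in_dim = len(weight), len(weight[0])
    fused = [row[:] for row in weight]  # copy so the base weight is untouched
    for a, b, s in loras:
        rank = len(a)
        for i in range(out_dim):
            for j in range(in_dim):
                fused[i][j] += s * sum(b[i][k] * a[k][j] for k in range(rank))
    return fused

# Rank-1 example: B @ A is the outer product of [1, 0] and [2, 3],
# scaled by strength 0.5 and added to a zero base weight.
w = [[0.0, 0.0], [0.0, 0.0]]
fused = fuse_lora(w, [([[2.0, 3.0]], [[1.0], [0.0]], 0.5)])
```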
ltx2/ltx_core/loader/kernels.py ADDED
@@ -0,0 +1,72 @@
+# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def fused_add_round_kernel(
+    x_ptr,
+    output_ptr,  # contents will be added to the output
+    seed,
+    n_elements,
+    EXPONENT_BIAS,
+    MANTISSA_BITS,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    A kernel to upcast 8-bit quantized weights to bfloat16 with stochastic
+    rounding and add them to bfloat16 output weights. May be used to upcast
+    original model weights and add them to precalculated deltas coming from LoRAs.
+    """
+    # Get program ID and compute offsets
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load data
+    x = tl.load(x_ptr + offsets, mask=mask)
+    rand_vals = tl.rand(seed, offsets) - 0.5
+
+    x = tl.cast(x, tl.float16)
+    delta = tl.load(output_ptr + offsets, mask=mask)
+    delta = tl.cast(delta, tl.float16)
+    x = x + delta
+
+    x_bits = tl.cast(x, tl.int16, bitcast=True)
+
+    # Calculate the exponent. The unbiased fp16 exponent is
+    # ((x_bits & 0x7C00) >> 10) - 15 for normal numbers and -14 for subnormals.
+    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
+    fp16_normals = fp16_exponent_bits > 0
+    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)
+
+    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
+    exponent = fp16_exponent + EXPONENT_BIAS
+    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
+    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
+    exponent = tl.where(exponent < 0, 0, exponent)
+
+    # Normal ULP exponent, expressed as an fp16 exponent field:
+    #   (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
+    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
+    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
+    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))
+
+    # Calculate epsilon in the target dtype
+    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)
+
+    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
+    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
+    # 16 - EXPONENT_BIAS - MANTISSA_BITS
+    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
+    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)
+
+    # Apply zero mask to epsilon
+    eps = tl.where(x == 0, 0.0, eps)
+
+    # Apply stochastic rounding
+    output = tl.cast(x + rand_vals * eps, tl.bfloat16)
+
+    # Store the result
+    tl.store(output_ptr + offsets, output, mask=mask)
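The kernel's `x + rand_vals * eps` step is the classic stochastic-rounding trick: adding uniform noise of one ULP's width before round-to-nearest makes the rounded result unbiased in expectation. A scalar sketch of the same idea (the `stochastic_round` helper below is hypothetical, stdlib only, with a fixed-size ULP instead of the kernel's per-element `eps`):

```python
import random

def stochastic_round(x: float, ulp: float, rng: random.Random) -> float:
    # Uniform noise in [-ulp/2, +ulp/2) followed by round-to-nearest is
    # equivalent to rounding up with probability equal to the fractional
    # distance, so E[result] == x (same trick as rand_vals = rand - 0.5
    # scaled by eps in the kernel).
    return round((x + (rng.random() - 0.5) * ulp) / ulp) * ulp

rng = random.Random(0)
samples = [stochastic_round(0.3, 1.0, rng) for _ in range(10_000)]
mean = sum(samples) / len(samples)  # close to 0.3 in expectation
```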
ltx2/ltx_core/loader/module_ops.py ADDED
@@ -0,0 +1,14 @@
+from typing import Callable, NamedTuple
+
+import torch
+
+
+class ModuleOps(NamedTuple):
+    """
+    Defines a named operation for matching and mutating PyTorch modules.
+
+    Used to selectively transform modules in a model (e.g., replacing layers
+    with quantized versions).
+    """
+
+    name: str
+    matcher: Callable[[torch.nn.Module], bool]
+    mutator: Callable[[torch.nn.Module], torch.nn.Module]
ltx2/ltx_core/loader/primitives.py ADDED
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, NamedTuple, Protocol
+
+import torch
+
+from ltx_core.loader.module_ops import ModuleOps
+from ltx_core.loader.sd_ops import SDOps
+from ltx_core.model.model_protocol import ModelType
+
+if TYPE_CHECKING:
+    from ltx_core.loader.registry import Registry
+
+
+@dataclass(frozen=True)
+class StateDict:
+    """
+    Immutable container for a PyTorch state dictionary.
+
+    Contains:
+    - sd: Dictionary of tensors (weights, buffers, etc.)
+    - device: Device where tensors are stored
+    - size: Total memory footprint in bytes
+    - dtype: Set of tensor dtypes present
+    """
+
+    sd: dict
+    device: torch.device
+    size: int
+    dtype: set[torch.dtype]
+
+    def footprint(self) -> tuple[int, torch.device]:
+        return self.size, self.device
+
+
+class StateDictLoader(Protocol):
+    """
+    Protocol for loading state dictionaries from various sources.
+
+    Implementations must provide:
+    - metadata: Extract model metadata from a single path
+    - load: Load state dict from path(s) and apply SDOps transformations
+    """
+
+    def metadata(self, path: str) -> dict:
+        """
+        Load metadata from path.
+        """
+
+    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
+        """
+        Load state dict from a path or paths (for sharded model storage) and apply sd_ops.
+        """
+
+
+class ModelBuilderProtocol(Protocol[ModelType]):
+    """
+    Protocol for building PyTorch models from configuration dictionaries.
+
+    Implementations must provide:
+    - meta_model: Create a model from a configuration dictionary and apply module operations
+    - build: Create and initialize a model from a state dictionary and apply dtype transformations
+    """
+
+    model_sd_ops: SDOps | None
+    module_ops: tuple[ModuleOps, ...]
+    loras: tuple["LoraPathStrengthAndSDOps", ...]
+    registry: "Registry"
+
+    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
+        """
+        Create a model on the meta device from a configuration dictionary.
+
+        This decouples model creation from weight loading, allowing the model
+        architecture to be instantiated without allocating memory for parameters.
+
+        Args:
+            config: Model configuration dictionary.
+            module_ops: Optional list of module operations to apply (e.g., quantization).
+
+        Returns:
+            Model instance on meta device (no actual memory allocated for parameters).
+        """
+        ...
+
+    def with_sd_ops(self, sd_ops: SDOps | None) -> "ModelBuilderProtocol[ModelType]":
+        """Return a copy of this builder with the given state-dict key remapping ops."""
+        ...
+
+    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "ModelBuilderProtocol[ModelType]":
+        """Return a copy of this builder with the given module operations (e.g. quantization)."""
+        ...
+
+    def with_loras(self, loras: tuple["LoraPathStrengthAndSDOps", ...]) -> "ModelBuilderProtocol[ModelType]":
+        """Return a copy of this builder with the given LoRAs to fuse at build time."""
+        ...
+
+    def with_registry(self, registry: "Registry") -> "ModelBuilderProtocol[ModelType]":
+        """Return a copy of this builder using the given weight registry for allocation."""
+        ...
+
+    def with_lora_load_device(self, device: torch.device) -> "ModelBuilderProtocol[ModelType]":
+        """Return a copy of this builder that loads LoRA weights onto the given device."""
+        ...
+
+    def build(
+        self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object
+    ) -> ModelType:
+        """
+        Build the model.
+
+        Args:
+            device: Target device for the model.
+            dtype: Target dtype for the model; if None, uses the dtype of the model_path model.
+
+        Returns:
+            Model instance
+        """
+        ...
+
+    def model_config(self) -> dict:
+        """Return the model configuration dictionary extracted from the checkpoint metadata."""
+        ...
+
+
+class LoRAAdaptableProtocol(Protocol):
+    """
+    Protocol for models that can be adapted with LoRAs.
+
+    Implementations must provide:
+    - lora: Add a LoRA to the model
+    """
+
+    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
+        pass
+
+
+class LoraPathStrengthAndSDOps(NamedTuple):
+    """
+    Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict.
+    """
+
+    path: str
+    strength: float
+    sd_ops: SDOps
+
+
+class LoraStateDictWithStrength(NamedTuple):
+    """
+    Tuple containing a LoRA state dict and strength for applying to the model.
+    """
+
+    state_dict: StateDict
+    strength: float
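The `with_*` methods on `ModelBuilderProtocol` all follow the same immutable-copy pattern: each call returns a modified copy, leaving the original builder intact so configurations can be shared and forked safely. A minimal sketch with a hypothetical frozen dataclass (not the real builder):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Builder:
    """Hypothetical stand-in for a builder following the with_* pattern."""

    loras: tuple[str, ...] = ()

    def with_loras(self, loras: tuple[str, ...]) -> "Builder":
        # dataclasses.replace returns a new frozen instance with the field
        # swapped, which is how the with_* methods avoid mutation.
        return replace(self, loras=loras)

base = Builder()
tuned = base.with_loras(("style.safetensors",))
```

Because the copies are frozen, a base builder can be configured once and specialized per call site without defensive copying.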
ltx2/ltx_core/loader/registry.py ADDED
@@ -0,0 +1,84 @@
+import hashlib
+import threading
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Protocol
+
+from ltx_core.loader.primitives import StateDict
+from ltx_core.loader.sd_ops import SDOps
+
+
+class Registry(Protocol):
+    """
+    Protocol for managing state dictionaries in a registry.
+
+    It is used to store state dictionaries and reuse them later without loading them again.
+    Implementations must provide:
+    - add: Add a state dictionary to the registry
+    - pop: Remove a state dictionary from the registry
+    - get: Retrieve a state dictionary from the registry
+    - clear: Clear all state dictionaries from the registry
+    """
+
+    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...
+
+    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
+
+    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
+
+    def clear(self) -> None: ...
+
+
+class DummyRegistry(Registry):
+    """
+    Dummy registry that does not store state dictionaries.
+    """
+
+    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
+        pass
+
+    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        pass
+
+    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        pass
+
+    def clear(self) -> None:
+        pass
+
+
+@dataclass
+class StateDictRegistry(Registry):
+    """
+    Registry that stores state dictionaries in a dictionary.
+    """
+
+    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
+    _lock: threading.Lock = field(default_factory=threading.Lock)
+
+    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
+        m = hashlib.sha256()
+        parts = [str(Path(p).resolve()) for p in paths]
+        if sd_ops is not None:
+            parts.append(sd_ops.name)
+        m.update("\0".join(parts).encode("utf-8"))
+        return m.hexdigest()
+
+    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
+        sd_id = self._generate_id(paths, sd_ops)
+        with self._lock:
+            if sd_id in self._state_dicts:
+                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
+            self._state_dicts[sd_id] = state_dict
+        return sd_id
+
+    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        with self._lock:
+            return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)
+
+    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
+        with self._lock:
+            return self._state_dicts.get(self._generate_id(paths, sd_ops), None)
+
+    def clear(self) -> None:
+        with self._lock:
+            self._state_dicts.clear()
ltx2/ltx_core/loader/sd_ops.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, replace
2
+ from typing import NamedTuple, Protocol
3
+
4
+ import torch
5
+
6
+
7
+ @dataclass(frozen=True, slots=True)
8
+ class ContentReplacement:
9
+ """
10
+ Represents a content replacement operation.
11
+ Used to replace a specific content with a replacement in a state dict key.
12
+ """
13
+
14
+ content: str
15
+ replacement: str
16
+
17
+
18
+ @dataclass(frozen=True, slots=True)
19
+ class ContentMatching:
20
+ """
21
+ Represents a content matching operation.
22
+ Used to match a specific prefix and suffix in a state dict key.
23
+ """
24
+
25
+ prefix: str = ""
26
+ suffix: str = ""
27
+
28
+
29
+ class KeyValueOperationResult(NamedTuple):
30
+ """
31
+ Represents the result of a key-value operation.
32
+ Contains the new key and value after the operation has been applied.
33
+ """
34
+
35
+ new_key: str
36
+ new_value: torch.Tensor
37
+
38
+
39
+ class KeyValueOperation(Protocol):
40
+ """
41
+ Protocol for key-value operations.
42
+ Used to apply operations to a specific key and value in a state dict.
43
+ """
44
+
45
+ def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
46
+
47
+
48
+ @dataclass(frozen=True, slots=True)
49
+ class SDKeyValueOperation:
50
+ """
51
+ Represents a key-value operation.
52
+ Used to apply operations to a specific key and value in a state dict.
53
+ """
54
+
55
+ key_matcher: ContentMatching
56
+ kv_operation: KeyValueOperation
57
+
58
+
59
+ @dataclass(frozen=True, slots=True)
60
+ class SDOps:
61
+ """Immutable class representing state dict key operations."""
62
+
63
+ name: str
64
+ mapping: tuple[
65
+ ContentReplacement | ContentMatching | SDKeyValueOperation, ...
66
+ ] = () # Immutable tuple of (key, value) pairs
67
+ allowed_keys: frozenset[str] | None = None
68
+
69
+ def with_replacement(self, content: str, replacement: str) -> "SDOps":
70
+ """Create a new SDOps instance with the specified replacement added to the mapping."""
71
+
72
+ new_mapping = (*self.mapping, ContentReplacement(content, replacement))
73
+ return replace(self, mapping=new_mapping)
74
+
75
+ def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
76
+ """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""
77
+
78
+ new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
79
+ return replace(self, mapping=new_mapping)
80
+
81
+ def with_additional_allowed_keys(self, keys: frozenset[str]) -> "SDOps":
82
+ """Create a new SDOps instance that only passes keys present in *keys* (post-replacement).
83
+ If allowed_keys already exists, the sets are merged via union.
84
+ """
85
+ merged = frozenset(keys) | self.allowed_keys if self.allowed_keys is not None else frozenset(keys)
86
+ return replace(self, allowed_keys=merged)
87
+
88
+ def with_kv_operation(
89
+ self,
90
+ operation: KeyValueOperation,
91
+ key_prefix: str = "",
92
+ key_suffix: str = "",
93
+ ) -> "SDOps":
94
+ """Create a new SDOps instance with the specified value operation added to the mapping."""
95
+ key_matcher = ContentMatching(key_prefix, key_suffix)
96
+ sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
97
+ new_mapping = (*self.mapping, sd_kv_operation)
98
+ return replace(self, mapping=new_mapping)
99
+
100
+ def apply_to_key(self, key: str) -> str | None:
101
+ """Apply the mapping to the given name."""
102
+ matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
103
+ valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
104
+ if not valid:
105
+ return None
106
+
107
+ for replacement in self.mapping:
108
+ if not isinstance(replacement, ContentReplacement):
109
+ continue
110
+ if replacement.content in key:
111
+ key = key.replace(replacement.content, replacement.replacement)
112
+
113
+ if self.allowed_keys is not None and key not in self.allowed_keys:
114
+ return None
115
+
116
+ return key
117
+
118
+ def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
119
+ """Apply the value operation to the given name and associated value."""
120
+ for operation in self.mapping:
121
+ if not isinstance(operation, SDKeyValueOperation):
122
+ continue
123
+ if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
124
+ return operation.kv_operation(key, value)
125
+ return [KeyValueOperationResult(key, value)]
126
+
127
+
128
+ # Predefined SDOps instances
129
+ LTXV_LORA_COMFY_RENAMING_MAP = (
130
+ SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
131
+ )
132
+
133
+ LTXV_LORA_COMFY_TARGET_MAP = (
134
+ SDOps("LTXV_LORA_COMFY_TARGET_MAP")
135
+ .with_matching()
136
+ .with_replacement("diffusion_model.", "")
137
+ .with_replacement(".lora_A.weight", ".weight")
138
+ .with_replacement(".lora_B.weight", ".weight")
139
+ )
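The two maps above chain `with_matching()` and `with_replacement()`; the resulting `apply_to_key` behavior (a match gate followed by string replacements) can be reproduced standalone. The dataclasses below are simplified stand-ins for the vendored `ContentMatching`/`ContentReplacement`, a sketch rather than the actual module:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ContentMatching:
    prefix: str = ""
    suffix: str = ""


@dataclass(frozen=True)
class ContentReplacement:
    content: str
    replacement: str


def apply_to_key(mapping, key):
    # A key passes only if at least one matcher accepts it...
    matchers = [m for m in mapping if isinstance(m, ContentMatching)]
    if not any(key.startswith(m.prefix) and key.endswith(m.suffix) for m in matchers):
        return None
    # ...then every replacement whose content occurs in the key is applied.
    for r in mapping:
        if isinstance(r, ContentReplacement) and r.content in key:
            key = key.replace(r.content, r.replacement)
    return key


mapping = (ContentMatching(), ContentReplacement("diffusion_model.", ""))
renamed = apply_to_key(mapping, "diffusion_model.blocks.0.attn.weight")  # "blocks.0.attn.weight"
```

An empty `ContentMatching()` matches every key, which is why the predefined maps call `.with_matching()` with no arguments before adding replacements.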
ltx2/ltx_core/loader/sft_loader.py ADDED
@@ -0,0 +1,66 @@
1
+ import json
2
+
3
+ import safetensors
4
+ import torch
5
+
6
+ from ltx_core.loader.primitives import StateDict, StateDictLoader
7
+ from ltx_core.loader.sd_ops import SDOps
8
+
9
+
10
+ class SafetensorsStateDictLoader(StateDictLoader):
11
+ """
12
+ Loads weights from safetensors files without metadata support.
13
+ Use this for loading raw weight files. For model files that include
14
+ configuration metadata, use SafetensorsModelStateDictLoader instead.
15
+ """
16
+
17
+ def metadata(self, path: str) -> dict:
18
+ raise NotImplementedError("Not implemented")
19
+
20
+ def load(self, path: str | list[str], sd_ops: SDOps | None, device: torch.device | None = None) -> StateDict:
21
+ """
22
+ Load a state dict from a path or list of shard paths and apply sd_ops to each key.
23
+ """
24
+ sd = {}
25
+ size = 0
26
+ dtype = set()
27
+ device = device or torch.device("cpu")
28
+ model_paths = path if isinstance(path, list) else [path]
29
+ for shard_path in model_paths:
30
+ with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
31
+ safetensor_keys = f.keys()
32
+ for name in safetensor_keys:
33
+ expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
34
+ if expected_name is None:
35
+ continue
36
+ value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
37
+ key_value_pairs = ((expected_name, value),)
38
+ if sd_ops is not None:
39
+ key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
40
+ for key, value in key_value_pairs:
41
+ size += value.nbytes
42
+ dtype.add(value.dtype)
43
+ sd[key] = value
44
+
45
+ return StateDict(sd=sd, device=device, size=size, dtype=dtype)
46
+
47
+
48
+ class SafetensorsModelStateDictLoader(StateDictLoader):
49
+ """
50
+ Loads weights and configuration metadata from safetensors model files.
51
+ Unlike SafetensorsStateDictLoader, this loader can read model configuration
52
+ from the safetensors file metadata via the metadata() method.
53
+ """
54
+
55
+ def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
56
+ self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader()
57
+
58
+ def metadata(self, path: str) -> dict:
59
+ with safetensors.safe_open(path, framework="pt") as f:
60
+ meta = f.metadata()
61
+ if meta is None or "config" not in meta:
62
+ return {}
63
+ return json.loads(meta["config"])
64
+
65
+ def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
66
+ return self.weight_loader.load(path, sd_ops, device)
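The shard-merging loop in `SafetensorsStateDictLoader.load` boils down to: iterate shards, rename or filter each key, and accumulate tensors plus a byte count. A minimal stdlib sketch with plain dicts standing in for safetensors shards (`merge_shards` and `strip` are hypothetical names, not part of the loader):

```python
def merge_shards(shards, rename=None):
    """Merge shard dicts into one state dict, renaming/filtering keys.

    `rename` maps an original key to a new key, or returns None to drop
    the key entirely, mirroring SDOps.apply_to_key.
    """
    sd, size = {}, 0
    for shard in shards:
        for name, value in shard.items():
            key = name if rename is None else rename(name)
            if key is None:
                continue
            sd[key] = value
            size += len(value)  # stand-in for value.nbytes
    return sd, size


shards = [{"diffusion_model.a": b"\x00" * 4}, {"diffusion_model.b": b"\x00" * 8}]
strip = lambda k: k.removeprefix("diffusion_model.")
merged, total = merge_shards(shards, strip)  # {"a": ..., "b": ...}, 12
```

The real loader additionally runs `apply_to_key_value` per key, which can fan one tensor out into several result entries.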
ltx2/ltx_core/loader/single_gpu_model_builder.py ADDED
@@ -0,0 +1,136 @@
1
+ import logging
2
+ from dataclasses import dataclass, field, replace
3
+ from typing import Generic
4
+
5
+ import torch
6
+
7
+ from ltx_core.loader.fuse_loras import apply_loras
8
+ from ltx_core.loader.module_ops import ModuleOps
9
+ from ltx_core.loader.primitives import (
10
+ LoRAAdaptableProtocol,
11
+ LoraPathStrengthAndSDOps,
12
+ LoraStateDictWithStrength,
13
+ ModelBuilderProtocol,
14
+ StateDict,
15
+ StateDictLoader,
16
+ )
17
+ from ltx_core.loader.registry import DummyRegistry, Registry
18
+ from ltx_core.loader.sd_ops import SDOps
19
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
20
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
21
+
22
+ logger: logging.Logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
27
+ """
28
+ Builder for PyTorch models residing on a single GPU.
29
+ Attributes:
30
+ model_class_configurator: Class responsible for constructing the model from a config dict.
31
+ model_path: Path (or tuple of shard paths) to the model's `.safetensors` checkpoint(s).
32
+ model_sd_ops: Optional state-dict operations applied when loading the model weights.
33
+ module_ops: Sequence of module-level mutations applied to the meta model before weight loading.
34
+ loras: Sequence of LoRA adapters (path, strength, optional sd_ops) to fuse into the model.
35
+ model_loader: Strategy for loading state dicts from disk. Defaults to
36
+ :class:`SafetensorsModelStateDictLoader`.
37
+ registry: Cache for already-loaded state dicts. Defaults to :class:`DummyRegistry` (no caching).
38
+ lora_load_device: Device used when loading LoRA weight tensors from disk. Defaults to
39
+ ``torch.device("cpu")``, which keeps LoRA weights in CPU memory and transfers them to
40
+ the target GPU sequentially during fusion, reducing peak GPU memory usage compared to
41
+ loading all LoRA weights directly onto the GPU at once.
42
+ """
43
+
44
+ model_class_configurator: type[ModelConfigurator[ModelType]]
45
+ model_path: str | tuple[str, ...]
46
+ model_sd_ops: SDOps | None = None
47
+ module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
48
+ loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
49
+ model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
50
+ registry: Registry = field(default_factory=DummyRegistry)
51
+ lora_load_device: torch.device = field(default_factory=lambda: torch.device("cpu"))
52
+
53
+ def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
54
+ return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))
55
+
56
+ def with_sd_ops(self, sd_ops: SDOps | None) -> "SingleGPUModelBuilder":
57
+ return replace(self, model_sd_ops=sd_ops)
58
+
59
+ def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "SingleGPUModelBuilder":
60
+ return replace(self, module_ops=module_ops)
61
+
62
+ def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> "SingleGPUModelBuilder":
63
+ return replace(self, loras=loras)
64
+
65
+ def with_registry(self, registry: Registry) -> "SingleGPUModelBuilder":
66
+ return replace(self, registry=registry)
67
+
68
+ def with_lora_load_device(self, device: torch.device) -> "SingleGPUModelBuilder":
69
+ return replace(self, lora_load_device=device)
70
+
71
+ def model_config(self) -> dict:
72
+ first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
73
+ return self.model_loader.metadata(first_shard_path)
74
+
75
+ def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
76
+ with torch.device("meta"):
77
+ model = self.model_class_configurator.from_config(config)
78
+ for module_op in module_ops:
79
+ if module_op.matcher(model):
80
+ model = module_op.mutator(model)
81
+ return model
82
+
83
+ def load_sd(
84
+ self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
85
+ ) -> StateDict:
86
+ state_dict = registry.get(paths, sd_ops)
87
+ if state_dict is None:
88
+ state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
89
+ registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
90
+ return state_dict
91
+
92
+ def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
93
+ uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
94
+ uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
95
+ if uninitialized_params or uninitialized_buffers:
96
+ logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
97
+ return meta_model
98
+ retval = meta_model.to(device)
99
+ return retval
100
+
101
+ def build(
102
+ self,
103
+ device: torch.device | None = None,
104
+ dtype: torch.dtype | None = None,
105
+ **kwargs: object, # noqa: ARG002
106
+ ) -> ModelType:
107
+ device = torch.device("cuda") if device is None else device
108
+ config = self.model_config()
109
+ meta_model = self.meta_model(config, self.module_ops)
110
+ model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path]
111
+ model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)
112
+
113
+ lora_strengths = [lora.strength for lora in self.loras]
114
+ if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
115
+ sd = model_state_dict.sd
116
+ if dtype is not None:
117
+ sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
118
+ meta_model.load_state_dict(sd, strict=False, assign=True)
119
+ return self._return_model(meta_model, device)
120
+
121
+ lora_state_dicts = [
122
+ self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=self.lora_load_device)
123
+ for lora in self.loras
124
+ ]
125
+ lora_sd_and_strengths = [
126
+ LoraStateDictWithStrength(sd, strength)
127
+ for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
128
+ ]
129
+ final_sd = apply_loras(
130
+ model_sd=model_state_dict,
131
+ lora_sd_and_strengths=lora_sd_and_strengths,
132
+ dtype=dtype,
133
+ destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
134
+ )
135
+ meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
136
+ return self._return_model(meta_model, device)
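`SingleGPUModelBuilder`'s fluent `lora`/`with_*` methods all follow the same pattern: a frozen dataclass plus `dataclasses.replace`, so every call returns a new builder and the original stays reusable. The pattern in miniature (the `Builder` class here is illustrative, not the vendored one):

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class Builder:
    model_path: str
    loras: tuple = field(default_factory=tuple)

    def lora(self, path: str, strength: float = 1.0) -> "Builder":
        # Each call returns a fresh Builder; self is never mutated.
        return replace(self, loras=(*self.loras, (path, strength)))


base = Builder("model.safetensors")
tuned = base.lora("style.safetensors", 0.8).lora("detail.safetensors")
```

Because the builder is immutable, `base` can be reused to derive differently-configured variants without interference between them.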
ltx2/ltx_core/modality_tiling.py ADDED
@@ -0,0 +1,222 @@
1
+ """Video modality tiling helpers.
2
+ Provides :class:`VideoModalityTilingHelper` — a stateless helper that
3
+ tiles and blends video :class:`Modality` token sequences by
4
+ spatial/temporal region. Tile geometry is represented by the existing
5
+ :class:`Tile` NamedTuple from :mod:`ltx_core.tiling`; no distributed
6
+ primitives are required.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass, replace
12
+
13
+ import torch
14
+
15
+ from ltx_core.model.transformer.modality import Modality
16
+ from ltx_core.tiling import Tile, TileCountConfig, create_tiles, identity_mapping_operation, split_by_count
17
+ from ltx_core.tools import VideoLatentTools
18
+ from ltx_core.types import VideoLatentShape
19
+
20
+
21
+ @dataclass(frozen=True)
22
+ class TilingContext:
23
+ """Opaque context produced by :meth:`VideoModalityTilingHelper.tile_modality`.
24
+ Carries the token-level keep mask and per-conditioning-token blend
25
+ weights needed by :meth:`~VideoModalityTilingHelper.blend`.
26
+ """
27
+
28
+ keep_mask: torch.Tensor
29
+ cond_blend_weights: torch.Tensor | None
30
+ """``(num_kept_cond,)`` — weight for each kept conditioning token,
31
+ equal to ``1 / num_tiles_that_keep_this_token``. ``None`` when
32
+ there are no conditioning tokens."""
33
+
34
+
35
+ class VideoModalityTilingHelper:
36
+ """Stateless helper that tiles and blends video :class:`Modality` sequences.
37
+ Constructed once with a :class:`TileCountConfig` and
38
+ :class:`VideoLatentTools`. Tiles are computed at construction and
39
+ available via the :attr:`tiles` property. Use :meth:`tile_modality`
40
+ and :meth:`blend` with any tile from that list.
41
+ Usage::
42
+ helper = VideoModalityTilingHelper(tiling, video_tools)
43
+ for tile in helper.tiles:
44
+ tiled_mod, ctx = helper.tile_modality(modality, tile)
45
+ result = run_model(tiled_mod)
46
+ helper.blend(result, tile, ctx, output=output)
47
+ """
48
+
49
+ def __init__(self, tiling: TileCountConfig, video_tools: VideoLatentTools) -> None:
50
+ self._patchifier = video_tools.patchifier
51
+ self._latent_shape = video_tools.target_shape
52
+ self._num_generated_tokens = self._patchifier.get_token_count(self._latent_shape)
53
+ self._tiles = create_tiles(
54
+ torch.Size([self._latent_shape.frames, self._latent_shape.height, self._latent_shape.width]),
55
+ splitters=[
56
+ split_by_count(tiling.frames.num_tiles, tiling.frames.overlap),
57
+ split_by_count(tiling.height.num_tiles, tiling.height.overlap),
58
+ split_by_count(tiling.width.num_tiles, tiling.width.overlap),
59
+ ],
60
+ mappers=[identity_mapping_operation] * 3,
61
+ )
62
+
63
+ @property
64
+ def tiles(self) -> list[Tile]:
65
+ """All tiles for the configured tiling layout."""
66
+ return self._tiles
67
+
68
+ # -- tile modality -----------------------------------------------------
69
+
70
+ def tile_modality(self, modality: Modality, tile: Tile) -> tuple[Modality, TilingContext]:
71
+ """Slice *modality* to the tokens covered by *tile*.
72
+ Selects generated tokens belonging to the tile's spatial region
73
+ and conditioning tokens that overlap with the tile (or have
74
+ negative time coordinates).
75
+ Returns:
76
+ A ``(tiled_modality, context)`` tuple. Pass *context* to
77
+ :meth:`blend` together with the model output.
78
+ """
79
+ keep_mask = self._keep_mask(modality, tile)
80
+
81
+ tile_attention_mask = None
82
+ if modality.attention_mask is not None:
83
+ keep_indices = keep_mask.nonzero(as_tuple=False).squeeze(1)
84
+ tile_attention_mask = modality.attention_mask[:, keep_indices, :][:, :, keep_indices]
85
+
86
+ tiled = replace(
87
+ modality,
88
+ latent=modality.latent[:, keep_mask, :],
89
+ timesteps=modality.timesteps[:, keep_mask],
90
+ positions=modality.positions[:, :, keep_mask, :],
91
+ attention_mask=tile_attention_mask,
92
+ )
93
+
94
+ cond_blend_weights = None
95
+ num_total = modality.latent.shape[1]
96
+ if num_total > self._num_generated_tokens:
97
+ cond_keep = keep_mask[self._num_generated_tokens :]
98
+ # Count how many tiles keep each conditioning token.
99
+ cond_counts = torch.zeros(int(cond_keep.sum()), dtype=torch.float32)
100
+ for t in self._tiles:
101
+ other_mask = self._keep_mask(modality, t)
102
+ other_cond = other_mask[self._num_generated_tokens :]
103
+ # Map other tile's kept cond tokens into this tile's kept subset.
104
+ cond_counts += other_cond[cond_keep].float()
105
+ cond_blend_weights = 1.0 / cond_counts
106
+
107
+ return tiled, TilingContext(keep_mask=keep_mask, cond_blend_weights=cond_blend_weights)
108
+
109
+ # -- blend -------------------------------------------------------------
110
+
111
+ def blend(
112
+ self,
113
+ tile_to_blend: torch.Tensor,
114
+ tile: Tile,
115
+ context: TilingContext,
116
+ output: torch.Tensor | None = None,
117
+ ) -> torch.Tensor:
118
+ """Blend-weight tile results and accumulate into the full token space.
119
+ Premultiplied (blend-weighted) data is **added** to *output*,
120
+ allowing multiple tiles to be accumulated into the same buffer.
121
+ Args:
122
+ tile_to_blend: Denoised tile tensor ``(B, num_tile_tokens, D)``,
123
+ where the first ``_tile_generated_token_count(tile)``
124
+ entries are generated tokens and the remainder are
125
+ conditioning tokens.
126
+ tile: The :class:`Tile` that was used in :meth:`tile_modality`.
127
+ context: The :class:`TilingContext` returned by :meth:`tile_modality`.
128
+ output: Optional pre-allocated output tensor. When provided
129
+ its shape must be ``(B, num_total_tokens, D)`` and the
130
+ blended tile is **added** into it. When ``None`` a new
131
+ zero-filled tensor is created.
132
+ Returns:
133
+ The output tensor with the blended tile added at the correct
134
+ positions.
135
+ """
136
+ batch, _, dim = tile_to_blend.shape
137
+ num_tile_gen = self._tile_generated_token_count(tile)
138
+ gen_indices = self._generated_token_indices(tile)
139
+
140
+ num_total_tokens = context.keep_mask.shape[0]
141
+ expected_shape = (batch, num_total_tokens, dim)
142
+
143
+ if output is not None:
144
+ if output.shape != expected_shape:
145
+ raise ValueError(f"Expected output shape {expected_shape}, got {output.shape}")
146
+ result = output
147
+ else:
148
+ result = torch.zeros(*expected_shape, device=tile_to_blend.device, dtype=tile_to_blend.dtype)
149
+
150
+ # Blend mask is (tile_F, tile_H, tile_W) — one weight per token in row-major order.
151
+ blend_weights = tile.blend_mask.reshape(-1).to(device=tile_to_blend.device, dtype=tile_to_blend.dtype)
152
+ tile_gen = tile_to_blend[:, :num_tile_gen, :] * blend_weights[None, :, None]
153
+
154
+ result[:, gen_indices, :] += tile_gen
155
+
156
+ # Scatter kept conditioning tokens, weighted by 1/N where N is
157
+ # the number of tiles that keep each token (so they sum to 1).
158
+ if num_total_tokens > self._num_generated_tokens and context.cond_blend_weights is not None:
159
+ cond_keep = context.keep_mask[self._num_generated_tokens :]
160
+ cond_indices = self._num_generated_tokens + cond_keep.nonzero(as_tuple=False).squeeze(1)
161
+ weights = context.cond_blend_weights.to(device=tile_to_blend.device, dtype=tile_to_blend.dtype)
162
+ result[:, cond_indices, :] += tile_to_blend[:, num_tile_gen:, :] * weights[None, :, None]
163
+
164
+ return result
165
+
166
+ # -- private -----------------------------------------------------------
167
+
168
+ def _tile_generated_token_count(self, tile: Tile) -> int:
169
+ """Number of generated tokens in *tile*."""
170
+ frame_slice, height_slice, width_slice = tile.in_coords
171
+ tile_shape = VideoLatentShape(
172
+ batch=self._latent_shape.batch,
173
+ channels=self._latent_shape.channels,
174
+ frames=frame_slice.stop - frame_slice.start,
175
+ height=height_slice.stop - height_slice.start,
176
+ width=width_slice.stop - width_slice.start,
177
+ )
178
+ return self._patchifier.get_token_count(tile_shape)
179
+
180
+ def _generated_token_indices(self, tile: Tile) -> torch.Tensor:
181
+ """Flat token indices of *tile*'s generated tokens in the full sequence."""
182
+ frame_slice, height_slice, width_slice = tile.in_coords
183
+ f = torch.arange(frame_slice.start, frame_slice.stop)
184
+ h = torch.arange(height_slice.start, height_slice.stop)
185
+ w = torch.arange(width_slice.start, width_slice.stop)
186
+ return (
187
+ f[:, None, None] * self._latent_shape.height * self._latent_shape.width
188
+ + h[None, :, None] * self._latent_shape.width
189
+ + w[None, None, :]
190
+ ).reshape(-1)
191
+
192
+ def _keep_mask(self, modality: Modality, tile: Tile) -> torch.Tensor:
193
+ """Boolean mask ``(num_total_tokens,)`` — True for tokens the tile processes.
194
+ Generated tokens are selected by grid position. Conditioning
195
+ tokens are kept when their ``[start, end)`` intervals overlap
196
+ the tile in all three dimensions, or when they have a negative
197
+ time coordinate (reference tokens).
198
+ """
199
+ num_total = modality.latent.shape[1]
200
+ mask = torch.zeros(num_total, dtype=torch.bool)
201
+
202
+ gen_indices = self._generated_token_indices(tile)
203
+ mask[gen_indices] = True
204
+
205
+ if num_total > self._num_generated_tokens:
206
+ gen_positions = modality.positions[:, :, gen_indices, :] # (B, 3, num_tile_gen, 2)
207
+ tile_start = gen_positions[..., 0].amin(dim=2) # (B, 3)
208
+ tile_end = gen_positions[..., 1].amax(dim=2) # (B, 3)
209
+
210
+ cond_positions = modality.positions[:, :, self._num_generated_tokens :, :] # (B, 3, num_cond, 2)
211
+
212
+ overlaps = (cond_positions[..., 0] < tile_end.unsqueeze(2)) & (
213
+ cond_positions[..., 1] > tile_start.unsqueeze(2)
214
+ ) # (B, 3, num_cond)
215
+ overlaps_all_dims = overlaps.all(dim=1) # (B, num_cond)
216
+
217
+ has_negative_time = cond_positions[:, 0, :, 0] < 0 # (B, num_cond)
218
+
219
+ keep_cond = (overlaps_all_dims | has_negative_time).any(dim=0) # (num_cond,)
220
+ mask[self._num_generated_tokens :] = keep_cond
221
+
222
+ return mask
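`_generated_token_indices` flattens `(f, h, w)` grid coordinates in row-major order via `f * H * W + h * W + w`. The same arithmetic in plain Python, with illustrative grid sizes:

```python
def tile_token_indices(f_slice, h_slice, w_slice, H, W):
    # Flat index of token (f, h, w) in a row-major F x H x W grid.
    return [
        f * H * W + h * W + w
        for f in range(*f_slice)
        for h in range(*h_slice)
        for w in range(*w_slice)
    ]


# A 1x2x2 tile at the origin of a grid with H=4, W=5:
idx = tile_token_indices((0, 1), (0, 2), (0, 2), H=4, W=5)  # [0, 1, 5, 6]
```

Row-major ordering here must match the order in which the patchifier emits tokens; otherwise the scatter in `blend` would write tile results to the wrong positions.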
ltx2/ltx_core/model/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """Model definitions for LTX-2."""
2
+
3
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
4
+
5
+ __all__ = [
6
+ "ModelConfigurator",
7
+ "ModelType",
8
+ ]
ltx2/ltx_core/model/audio_vae/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ """Audio VAE model components."""
2
+
3
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio, encode_audio
4
+ from ltx_core.model.audio_vae.model_configurator import (
5
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
6
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
7
+ VOCODER_COMFY_KEYS_FILTER,
8
+ AudioDecoderConfigurator,
9
+ AudioEncoderConfigurator,
10
+ VocoderConfigurator,
11
+ )
12
+ from ltx_core.model.audio_vae.ops import AudioProcessor
13
+ from ltx_core.model.audio_vae.vocoder import Vocoder, VocoderWithBWE
14
+
15
+ __all__ = [
16
+ "AUDIO_VAE_DECODER_COMFY_KEYS_FILTER",
17
+ "AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER",
18
+ "VOCODER_COMFY_KEYS_FILTER",
19
+ "AudioDecoder",
20
+ "AudioDecoderConfigurator",
21
+ "AudioEncoder",
22
+ "AudioEncoderConfigurator",
23
+ "AudioProcessor",
24
+ "Vocoder",
25
+ "VocoderConfigurator",
26
+ "VocoderWithBWE",
27
+ "decode_audio",
28
+ "encode_audio",
29
+ ]
ltx2/ltx_core/model/audio_vae/attention.py ADDED
@@ -0,0 +1,71 @@
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
6
+
7
+
8
+ class AttentionType(Enum):
9
+ """Enum for specifying the attention mechanism type."""
10
+
11
+ VANILLA = "vanilla"
12
+ LINEAR = "linear"
13
+ NONE = "none"
14
+
15
+
16
+ class AttnBlock(torch.nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_channels: int,
20
+ norm_type: NormType = NormType.GROUP,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.in_channels = in_channels
24
+
25
+ self.norm = build_normalization_layer(in_channels, normtype=norm_type)
26
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
27
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
28
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
29
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ h_ = x
33
+ h_ = self.norm(h_)
34
+ q = self.q(h_)
35
+ k = self.k(h_)
36
+ v = self.v(h_)
37
+
38
+ # compute attention
39
+ b, c, h, w = q.shape
40
+ q = q.reshape(b, c, h * w).contiguous()
41
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
42
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
43
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
44
+ w_ = w_ * (int(c) ** (-0.5))
45
+ w_ = torch.nn.functional.softmax(w_, dim=2)
46
+
47
+ # attend to values
48
+ v = v.reshape(b, c, h * w).contiguous()
49
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
50
+ h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
51
+ h_ = h_.reshape(b, c, h, w).contiguous()
52
+
53
+ h_ = self.proj_out(h_)
54
+
55
+ return x + h_
56
+
57
+
58
+ def make_attn(
59
+ in_channels: int,
60
+ attn_type: AttentionType = AttentionType.VANILLA,
61
+ norm_type: NormType = NormType.GROUP,
62
+ ) -> torch.nn.Module:
63
+ match attn_type:
64
+ case AttentionType.VANILLA:
65
+ return AttnBlock(in_channels, norm_type=norm_type)
66
+ case AttentionType.NONE:
67
+ return torch.nn.Identity()
68
+ case AttentionType.LINEAR:
69
+ raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
70
+ case _:
71
+ raise ValueError(f"Unknown attention type: {attn_type}")
ltx2/ltx_core/model/audio_vae/audio_vae.py ADDED
@@ -0,0 +1,508 @@
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ltx_core.components.patchifiers import AudioPatchifier
7
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
8
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
9
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
10
+ from ltx_core.model.audio_vae.downsample import build_downsampling_path
11
+ from ltx_core.model.audio_vae.ops import AudioProcessor, PerChannelStatistics
12
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
13
+ from ltx_core.model.audio_vae.upsample import build_upsampling_path
14
+ from ltx_core.model.audio_vae.vocoder import Vocoder
15
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
16
+ from ltx_core.types import Audio, AudioLatentShape
17
+
18
+ LATENT_DOWNSAMPLE_FACTOR = 4
19
+
20
+
21
+ def build_mid_block(
22
+ channels: int,
23
+ temb_channels: int,
24
+ dropout: float,
25
+ norm_type: NormType,
26
+ causality_axis: CausalityAxis,
27
+ attn_type: AttentionType,
28
+ add_attention: bool,
29
+ ) -> torch.nn.Module:
30
+ """Build the middle block with two ResNet blocks and optional attention."""
31
+ mid = torch.nn.Module()
32
+ mid.block_1 = ResnetBlock(
33
+ in_channels=channels,
34
+ out_channels=channels,
35
+ temb_channels=temb_channels,
36
+ dropout=dropout,
37
+ norm_type=norm_type,
38
+ causality_axis=causality_axis,
39
+ )
40
+ mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()
41
+ mid.block_2 = ResnetBlock(
42
+ in_channels=channels,
43
+ out_channels=channels,
44
+ temb_channels=temb_channels,
45
+ dropout=dropout,
46
+ norm_type=norm_type,
47
+ causality_axis=causality_axis,
48
+ )
49
+ return mid
50
+
51
+
52
+ def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
53
+ """Run features through the middle block."""
54
+ features = mid.block_1(features, temb=None)
55
+ features = mid.attn_1(features)
56
+ return mid.block_2(features, temb=None)
57
+
58
+
59
+ class AudioEncoder(torch.nn.Module):
60
+ """
61
+ Encoder that compresses audio spectrograms into latent representations.
62
+ The encoder uses a series of downsampling blocks with residual connections,
63
+ attention mechanisms, and configurable causal convolutions.
64
+ """
65
+
66
+ def __init__( # noqa: PLR0913
67
+ self,
68
+ *,
69
+ ch: int,
70
+ ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
71
+ num_res_blocks: int,
72
+ attn_resolutions: Set[int],
73
+ dropout: float = 0.0,
74
+ resamp_with_conv: bool = True,
75
+ in_channels: int,
76
+ resolution: int,
77
+ z_channels: int,
78
+ double_z: bool = True,
79
+ attn_type: AttentionType = AttentionType.VANILLA,
80
+ mid_block_add_attention: bool = True,
81
+ norm_type: NormType = NormType.GROUP,
82
+ causality_axis: CausalityAxis = CausalityAxis.WIDTH,
83
+ sample_rate: int = 16000,
84
+ mel_hop_length: int = 160,
85
+ n_fft: int = 1024,
86
+ is_causal: bool = True,
87
+ mel_bins: int = 64,
88
+ **_ignore_kwargs,
89
+ ) -> None:
90
+ """
91
+ Initialize the Encoder.
92
+ Args:
93
+ Arguments are configuration parameters, loaded from the audio VAE checkpoint config
94
+ (audio_vae.model.params.ddconfig):
95
+ ch: Base number of feature channels used in the first convolution layer.
96
+ ch_mult: Multiplicative factors for the number of channels at each resolution level.
97
+ num_res_blocks: Number of residual blocks to use at each resolution level.
98
+ attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
99
+ resolution: Input spatial resolution of the spectrogram (height, width).
100
+ z_channels: Number of channels in the latent representation.
101
+ norm_type: Normalization layer type to use within the network (e.g., group, batch).
102
+ causality_axis: Axis along which convolutions should be causal (e.g., time axis).
103
+ sample_rate: Audio sample rate in Hz for the input signals.
104
+ mel_hop_length: Hop length used when computing the mel spectrogram.
105
+ n_fft: FFT size used to compute the spectrogram.
106
+ mel_bins: Number of mel-frequency bins in the input spectrogram.
107
+ in_channels: Number of channels in the input spectrogram tensor.
108
+ double_z: If True, predict both mean and log-variance (doubling latent channels).
109
+ is_causal: If True, use causal convolutions suitable for streaming setups.
110
+ dropout: Dropout probability used in residual and mid blocks.
111
+ attn_type: Type of attention mechanism to use in attention blocks.
112
+ resamp_with_conv: If True, perform resolution changes using strided convolutions.
113
+ mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
114
+ """
115
+ super().__init__()
116
+
117
+ self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
118
+ self.sample_rate = sample_rate
119
+ self.mel_hop_length = mel_hop_length
120
+ self.n_fft = n_fft
121
+ self.is_causal = is_causal
122
+ self.mel_bins = mel_bins
123
+
124
+ self.patchifier = AudioPatchifier(
125
+ patch_size=1,
126
+ audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
127
+ sample_rate=sample_rate,
128
+ hop_length=mel_hop_length,
129
+ is_causal=is_causal,
130
+ )
131
+
132
+ self.ch = ch
133
+ self.temb_ch = 0
134
+ self.num_resolutions = len(ch_mult)
135
+ self.num_res_blocks = num_res_blocks
136
+ self.resolution = resolution
137
+ self.in_channels = in_channels
138
+ self.z_channels = z_channels
139
+ self.double_z = double_z
140
+ self.norm_type = norm_type
141
+ self.causality_axis = causality_axis
142
+ self.attn_type = attn_type
143
+
144
+ # downsampling
145
+ self.conv_in = make_conv2d(
146
+ in_channels,
147
+ self.ch,
148
+ kernel_size=3,
149
+ stride=1,
150
+ causality_axis=self.causality_axis,
151
+ )
152
+
153
+ self.non_linearity = torch.nn.SiLU()
154
+
155
+ self.down, block_in = build_downsampling_path(
156
+ ch=ch,
157
+ ch_mult=ch_mult,
158
+ num_resolutions=self.num_resolutions,
159
+ num_res_blocks=num_res_blocks,
160
+ resolution=resolution,
161
+ temb_channels=self.temb_ch,
162
+ dropout=dropout,
163
+ norm_type=self.norm_type,
164
+ causality_axis=self.causality_axis,
165
+ attn_type=self.attn_type,
166
+ attn_resolutions=attn_resolutions,
167
+ resamp_with_conv=resamp_with_conv,
168
+ )
169
+
170
+ self.mid = build_mid_block(
171
+ channels=block_in,
172
+ temb_channels=self.temb_ch,
173
+ dropout=dropout,
174
+ norm_type=self.norm_type,
175
+ causality_axis=self.causality_axis,
176
+ attn_type=self.attn_type,
177
+ add_attention=mid_block_add_attention,
178
+ )
179
+
180
+ self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
181
+ self.conv_out = make_conv2d(
182
+ block_in,
183
+ 2 * z_channels if double_z else z_channels,
184
+ kernel_size=3,
185
+ stride=1,
186
+ causality_axis=self.causality_axis,
187
+ )
188
+
189
+ def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
190
+ """
191
+ Encode audio spectrogram into latent representations.
192
+ Args:
193
+ spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
194
+ Returns:
195
+ Encoded latent representation of shape (batch, channels, frames, mel_bins)
196
+ """
197
+ h = self.conv_in(spectrogram)
198
+ h = self._run_downsampling_path(h)
199
+ h = run_mid_block(self.mid, h)
200
+ h = self._finalize_output(h)
201
+
202
+ return self._normalize_latents(h)
203
+
204
+ def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
205
+ for level in range(self.num_resolutions):
206
+ stage = self.down[level]
207
+ for block_idx in range(self.num_res_blocks):
208
+ h = stage.block[block_idx](h, temb=None)
209
+ if stage.attn:
210
+ h = stage.attn[block_idx](h)
211
+
212
+ if level != self.num_resolutions - 1:
213
+ h = stage.downsample(h)
214
+
215
+ return h
216
+
217
+ def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
218
+ h = self.norm_out(h)
219
+ h = self.non_linearity(h)
220
+ return self.conv_out(h)
221
+
222
+ def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
223
+ """
224
+ Normalize encoder latents using per-channel statistics.
225
+ When the encoder is configured with ``double_z=True``, the final
226
+ convolution produces twice the number of latent channels, typically
227
+ interpreted as two concatenated tensors along the channel dimension
228
+ (e.g., mean and log-variance).
229
+ This method intentionally uses only the first half of the channels
230
+ (the "mean" component) as input to the patchifier and normalization
231
+ logic. The remaining channels are left unchanged by this method and
232
+ are expected to be consumed elsewhere in the VAE pipeline.
233
+ Note that ``torch.chunk(..., 2, dim=1)`` always splits the channel
234
+ dimension in half, so this method effectively assumes ``double_z=True``.
235
+ """
236
+ means = torch.chunk(latent_output, 2, dim=1)[0]
237
+ latent_shape = AudioLatentShape(
238
+ batch=means.shape[0],
239
+ channels=means.shape[1],
240
+ frames=means.shape[2],
241
+ mel_bins=means.shape[3],
242
+ )
243
+ latent_patched = self.patchifier.patchify(means)
244
+ latent_normalized = self.per_channel_statistics.normalize(latent_patched)
245
+ return self.patchifier.unpatchify(latent_normalized, latent_shape)
246
+
247
+
248
+ def encode_audio(
249
+ audio: Audio,
250
+ audio_encoder: AudioEncoder,
251
+ audio_processor: AudioProcessor | None = None,
252
+ ) -> torch.Tensor:
253
+ """Encode audio waveform into latent representation.
254
+ Args:
255
+ audio: Audio container with waveform tensor of shape (batch, channels, samples) and sampling rate.
256
+ audio_encoder: Audio encoder model
257
+ audio_processor: Audio processor model (optional, if not provided, it will be created from the audio encoder)
258
+ """
259
+ dtype = next(audio_encoder.parameters()).dtype
260
+ device = next(audio_encoder.parameters()).device
261
+
262
+ if audio_processor is None:
263
+ audio_processor = AudioProcessor(
264
+ target_sample_rate=audio_encoder.sample_rate,
265
+ mel_bins=audio_encoder.mel_bins,
266
+ mel_hop_length=audio_encoder.mel_hop_length,
267
+ n_fft=audio_encoder.n_fft,
268
+ ).to(device=device)
269
+
270
+ mel_spectrogram = audio_processor.waveform_to_mel(audio.to(device=device))
271
+
272
+ latent = audio_encoder(mel_spectrogram.to(dtype=dtype))
273
+ return latent
274
+
275
+
276
+ class AudioDecoder(torch.nn.Module):
277
+ """
278
+ Symmetric decoder that reconstructs audio spectrograms from latent features.
279
+ The decoder mirrors the encoder structure with configurable channel multipliers,
280
+ attention resolutions, and causal convolutions.
281
+ """
282
+
283
+ def __init__( # noqa: PLR0913
284
+ self,
285
+ *,
286
+ ch: int,
287
+ out_ch: int,
288
+ ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
289
+ num_res_blocks: int,
290
+ attn_resolutions: Set[int],
291
+ resolution: int,
292
+ z_channels: int,
293
+ norm_type: NormType = NormType.GROUP,
294
+ causality_axis: CausalityAxis = CausalityAxis.WIDTH,
295
+ dropout: float = 0.0,
296
+ mid_block_add_attention: bool = True,
297
+ sample_rate: int = 16000,
298
+ mel_hop_length: int = 160,
299
+ is_causal: bool = True,
300
+ mel_bins: int | None = None,
301
+ ) -> None:
302
+ """
303
+ Initialize the Decoder.
304
+ Args:
305
+ Arguments are configuration parameters, loaded from the audio VAE checkpoint config
306
+ (audio_vae.model.params.ddconfig):
307
+ - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
308
+ - resolution, z_channels
309
+ - norm_type, causality_axis
310
+ """
311
+ super().__init__()
312
+
313
+ # Internal behavioural defaults that are not driven by the checkpoint.
314
+ resamp_with_conv = True
315
+ attn_type = AttentionType.VANILLA
316
+
317
+ # Per-channel statistics for denormalizing latents
318
+ self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
319
+ self.sample_rate = sample_rate
320
+ self.mel_hop_length = mel_hop_length
321
+ self.is_causal = is_causal
322
+ self.mel_bins = mel_bins
323
+ self.patchifier = AudioPatchifier(
324
+ patch_size=1,
325
+ audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
326
+ sample_rate=sample_rate,
327
+ hop_length=mel_hop_length,
328
+ is_causal=is_causal,
329
+ )
330
+
331
+ self.ch = ch
332
+ self.temb_ch = 0
333
+ self.num_resolutions = len(ch_mult)
334
+ self.num_res_blocks = num_res_blocks
335
+ self.resolution = resolution
336
+ self.out_ch = out_ch
337
+ self.give_pre_end = False
338
+ self.tanh_out = False
339
+ self.norm_type = norm_type
340
+ self.z_channels = z_channels
341
+ self.channel_multipliers = ch_mult
342
+ self.attn_resolutions = attn_resolutions
343
+ self.causality_axis = causality_axis
344
+ self.attn_type = attn_type
345
+
346
+ base_block_channels = ch * self.channel_multipliers[-1]
347
+ base_resolution = resolution // (2 ** (self.num_resolutions - 1))
348
+ self.z_shape = (1, z_channels, base_resolution, base_resolution)
349
+
350
+ self.conv_in = make_conv2d(
351
+ z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
352
+ )
353
+ self.non_linearity = torch.nn.SiLU()
354
+ self.mid = build_mid_block(
355
+ channels=base_block_channels,
356
+ temb_channels=self.temb_ch,
357
+ dropout=dropout,
358
+ norm_type=self.norm_type,
359
+ causality_axis=self.causality_axis,
360
+ attn_type=self.attn_type,
361
+ add_attention=mid_block_add_attention,
362
+ )
363
+ self.up, final_block_channels = build_upsampling_path(
364
+ ch=ch,
365
+ ch_mult=ch_mult,
366
+ num_resolutions=self.num_resolutions,
367
+ num_res_blocks=num_res_blocks,
368
+ resolution=resolution,
369
+ temb_channels=self.temb_ch,
370
+ dropout=dropout,
371
+ norm_type=self.norm_type,
372
+ causality_axis=self.causality_axis,
373
+ attn_type=self.attn_type,
374
+ attn_resolutions=attn_resolutions,
375
+ resamp_with_conv=resamp_with_conv,
376
+ initial_block_channels=base_block_channels,
377
+ )
378
+
379
+ self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
380
+ self.conv_out = make_conv2d(
381
+ final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
382
+ )
383
+
384
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
385
+ """
386
+ Decode latent features back to audio spectrograms.
387
+ Args:
388
+ sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
389
+ Returns:
390
+ Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
391
+ """
392
+ sample, target_shape = self._denormalize_latents(sample)
393
+
394
+ h = self.conv_in(sample)
395
+ h = run_mid_block(self.mid, h)
396
+ h = self._run_upsampling_path(h)
397
+ h = self._finalize_output(h)
398
+
399
+ return self._adjust_output_shape(h, target_shape)
400
+
401
+ def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
402
+ latent_shape = AudioLatentShape(
403
+ batch=sample.shape[0],
404
+ channels=sample.shape[1],
405
+ frames=sample.shape[2],
406
+ mel_bins=sample.shape[3],
407
+ )
408
+
409
+ sample_patched = self.patchifier.patchify(sample)
410
+ sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
411
+ sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
412
+
413
+ target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
414
+ if self.causality_axis != CausalityAxis.NONE:
415
+ target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
416
+
417
+ target_shape = AudioLatentShape(
418
+ batch=latent_shape.batch,
419
+ channels=self.out_ch,
420
+ frames=target_frames,
421
+ mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
422
+ )
423
+
424
+ return sample, target_shape
425
+
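The target-frame arithmetic above can be sketched in isolation. This is an illustrative snippet, assuming `LATENT_DOWNSAMPLE_FACTOR` is 4 (the real constant is defined elsewhere in this module):

```python
FACTOR = 4  # assumed value of LATENT_DOWNSAMPLE_FACTOR for this sketch

def target_frames(latent_frames: int, causal: bool) -> int:
    # Causal latents carry one warm-up position, so decoding recovers
    # latent_frames * FACTOR - (FACTOR - 1) time steps instead of the
    # full latent_frames * FACTOR.
    t = latent_frames * FACTOR
    return max(t - (FACTOR - 1), 1) if causal else t

causal_len = target_frames(10, causal=True)       # 10 * 4 - 3 = 37
non_causal_len = target_frames(10, causal=False)  # 10 * 4 = 40
```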
426
+ def _adjust_output_shape(
427
+ self,
428
+ decoded_output: torch.Tensor,
429
+ target_shape: AudioLatentShape,
430
+ ) -> torch.Tensor:
431
+ """
432
+ Adjust output shape to match target dimensions for variable-length audio.
433
+ This function handles the common case where decoded audio spectrograms need to be
434
+ resized to match a specific target shape.
435
+ Args:
436
+ decoded_output: Tensor of shape (batch, channels, time, frequency)
437
+ target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
438
+ Returns:
439
+ Tensor adjusted to match target_shape exactly
440
+ """
441
+ # Current output shape: (batch, channels, time, frequency)
442
+ _, _, current_time, current_freq = decoded_output.shape
443
+ target_channels = target_shape.channels
444
+ target_time = target_shape.frames
445
+ target_freq = target_shape.mel_bins
446
+
447
+ # Step 1: Crop first to avoid exceeding target dimensions
448
+ decoded_output = decoded_output[
449
+ :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
450
+ ]
451
+
452
+ # Step 2: Calculate padding needed for time and frequency dimensions
453
+ time_padding_needed = target_time - decoded_output.shape[2]
454
+ freq_padding_needed = target_freq - decoded_output.shape[3]
455
+
456
+ # Step 3: Apply padding if needed
457
+ if time_padding_needed > 0 or freq_padding_needed > 0:
458
+ # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
459
+ # For audio: pad_left/right = frequency, pad_top/bottom = time
460
+ padding = (
461
+ 0,
462
+ max(freq_padding_needed, 0), # frequency padding (left, right)
463
+ 0,
464
+ max(time_padding_needed, 0), # time padding (top, bottom)
465
+ )
466
+ decoded_output = F.pad(decoded_output, padding)
467
+
468
+ # Step 4: Final safety crop to ensure exact target shape
469
+ decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
470
+
471
+ return decoded_output
472
+
473
+ def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
474
+ for level in reversed(range(self.num_resolutions)):
475
+ stage = self.up[level]
476
+ for block_idx, block in enumerate(stage.block):
477
+ h = block(h, temb=None)
478
+ if stage.attn:
479
+ h = stage.attn[block_idx](h)
480
+
481
+ if level != 0 and hasattr(stage, "upsample"):
482
+ h = stage.upsample(h)
483
+
484
+ return h
485
+
486
+ def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
487
+ if self.give_pre_end:
488
+ return h
489
+
490
+ h = self.norm_out(h)
491
+ h = self.non_linearity(h)
492
+ h = self.conv_out(h)
493
+ return torch.tanh(h) if self.tanh_out else h
494
+
495
+
496
+ def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> Audio:
497
+ """
498
+ Decode an audio latent representation using the provided audio decoder and vocoder.
499
+ Args:
500
+ latent: Input audio latent tensor.
501
+ audio_decoder: Model to decode the latent to waveform features.
502
+ vocoder: Model to convert decoded features to audio waveform.
503
+ Returns:
504
+ Decoded audio with waveform and sampling rate.
505
+ """
506
+ decoded_audio = audio_decoder(latent)
507
+ waveform = vocoder(decoded_audio).squeeze(0).float()
508
+ return Audio(waveform=waveform, sampling_rate=vocoder.output_sampling_rate)
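The crop-then-pad logic in `_adjust_output_shape` can be reproduced as a small standalone sketch (the function name here is illustrative, not part of the module):

```python
import torch
import torch.nn.functional as F

def adjust_to_target(x: torch.Tensor, target_time: int, target_freq: int) -> torch.Tensor:
    # Crop first so no dimension exceeds the target...
    x = x[:, :, :target_time, :target_freq]
    # ...then zero-pad any shortfall. F.pad orders pads last-dim-first:
    # (freq_left, freq_right, time_top, time_bottom).
    pad_t = max(target_time - x.shape[2], 0)
    pad_f = max(target_freq - x.shape[3], 0)
    return F.pad(x, (0, pad_f, 0, pad_t))

# Time is too short (padded up), frequency is too long (cropped down):
out = adjust_to_target(torch.randn(1, 2, 10, 80), target_time=12, target_freq=64)
```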
ltx2/ltx_core/model/audio_vae/causal_conv_2d.py ADDED
@@ -0,0 +1,110 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
5
+
6
+
7
+ class CausalConv2d(torch.nn.Module):
8
+ """
9
+ A causal 2D convolution.
10
+ This layer ensures that the output at time `t` only depends on inputs
11
+ at time `t` and earlier. It achieves this by applying asymmetric padding
12
+ to the time dimension (width) before the convolution.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ out_channels: int,
19
+ kernel_size: int | tuple[int, int],
20
+ stride: int = 1,
21
+ dilation: int | tuple[int, int] = 1,
22
+ groups: int = 1,
23
+ bias: bool = True,
24
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ self.causality_axis = causality_axis
29
+
30
+ # Ensure kernel_size and dilation are tuples
31
+ kernel_size = torch.nn.modules.utils._pair(kernel_size)
32
+ dilation = torch.nn.modules.utils._pair(dilation)
33
+
34
+ # Calculate padding dimensions
35
+ pad_h = (kernel_size[0] - 1) * dilation[0]
36
+ pad_w = (kernel_size[1] - 1) * dilation[1]
37
+
38
+ # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
39
+ match self.causality_axis:
40
+ case CausalityAxis.NONE:
41
+ self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
42
+ case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
43
+ self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
44
+ case CausalityAxis.HEIGHT:
45
+ self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
46
+ case _:
47
+ raise ValueError(f"Invalid causality_axis: {causality_axis}")
48
+
49
+ # The internal convolution layer uses no padding, as we handle it manually
50
+ self.conv = torch.nn.Conv2d(
51
+ in_channels,
52
+ out_channels,
53
+ kernel_size,
54
+ stride=stride,
55
+ padding=0,
56
+ dilation=dilation,
57
+ groups=groups,
58
+ bias=bias,
59
+ )
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ # Apply causal padding before convolution
63
+ x = F.pad(x, self.padding)
64
+ return self.conv(x)
65
+
66
+
67
+ def make_conv2d(
68
+ in_channels: int,
69
+ out_channels: int,
70
+ kernel_size: int | tuple[int, int],
71
+ stride: int = 1,
72
+ padding: tuple[int, int, int, int] | None = None,
73
+ dilation: int = 1,
74
+ groups: int = 1,
75
+ bias: bool = True,
76
+ causality_axis: CausalityAxis | None = None,
77
+ ) -> torch.nn.Module:
78
+ """
79
+ Create a 2D convolution layer that can be either causal or non-causal.
80
+ Args:
81
+ in_channels: Number of input channels
82
+ out_channels: Number of output channels
83
+ kernel_size: Size of the convolution kernel
84
+ stride: Convolution stride
85
+ padding: Explicit padding for the non-causal case (if None, symmetric padding is derived from the kernel size)
86
+ dilation: Dilation rate
87
+ groups: Number of groups for grouped convolution
88
+ bias: Whether to use bias
89
+ causality_axis: Dimension along which to apply causality.
90
+ Returns:
91
+ Either a regular Conv2d or CausalConv2d layer
92
+ """
93
+ if causality_axis is not None:
94
+ # For causal convolution, padding is handled internally by CausalConv2d
95
+ return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
96
+ else:
97
+ # For non-causal convolution, use symmetric padding if not specified
98
+ if padding is None:
99
+ padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)
100
+
101
+ return torch.nn.Conv2d(
102
+ in_channels,
103
+ out_channels,
104
+ kernel_size,
105
+ stride,
106
+ padding,
107
+ dilation,
108
+ groups,
109
+ bias,
110
+ )
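As an illustration of the width-causal padding scheme (a standalone sketch, not an API from this module): for a 3x3 kernel, all width padding goes to the left, so an impulse in the last width position cannot influence any earlier output column.

```python
import torch
import torch.nn.functional as F

k = 3
# Same shape as the WIDTH branch above: (left, right, top, bottom).
pad = (k - 1, 0, (k - 1) // 2, (k - 1) - (k - 1) // 2)

x = torch.zeros(1, 1, 4, 8)
x[..., -1] = 1.0  # impulse in the final (most "future") width position
conv = torch.nn.Conv2d(1, 1, k, padding=0, bias=False)
torch.nn.init.ones_(conv.weight)
y = conv(F.pad(x, pad))
# Output keeps the input's spatial size, and every column before the
# impulse stays zero: the convolution never sees future width positions.
```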
ltx2/ltx_core/model/audio_vae/causality_axis.py ADDED
@@ -0,0 +1,10 @@
1
+ from enum import Enum
2
+
3
+
4
+ class CausalityAxis(Enum):
5
+ """Enum for specifying the causality axis in causal convolutions."""
6
+
7
+ NONE = None
8
+ WIDTH = "width"
9
+ HEIGHT = "height"
10
+ WIDTH_COMPATIBILITY = "width-compatibility"
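The configurators elsewhere in this package construct the enum from checkpoint strings via value lookup; a minimal self-contained sketch of that behaviour (the enum is redeclared here so the snippet stands alone):

```python
from enum import Enum

class CausalityAxis(Enum):
    """Mirror of the enum above, redeclared for a standalone example."""
    NONE = None
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"

# Enum(value) performs value-based lookup, which is how a string like
# ddconfig["causality_axis"] == "height" becomes a member:
axis = CausalityAxis("height")
```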
ltx2/ltx_core/model/audio_vae/downsample.py ADDED
@@ -0,0 +1,110 @@
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
8
+ from ltx_core.model.common.normalization import NormType
9
+
10
+
11
+ class Downsample(torch.nn.Module):
12
+ """
13
+ A downsampling layer that can use either a strided convolution
14
+ or average pooling. Supports standard and causal padding for the
15
+ convolutional mode.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ in_channels: int,
21
+ with_conv: bool,
22
+ causality_axis: CausalityAxis = CausalityAxis.WIDTH,
23
+ ) -> None:
24
+ super().__init__()
25
+ self.with_conv = with_conv
26
+ self.causality_axis = causality_axis
27
+
28
+ if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
29
+ raise ValueError("causality is only supported when `with_conv=True`.")
30
+
31
+ if self.with_conv:
32
+ # Do time downsampling here
33
+ # no asymmetric padding in torch conv, must do it ourselves
34
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ if self.with_conv:
38
+ # Padding tuple is in the order: (left, right, top, bottom).
39
+ match self.causality_axis:
40
+ case CausalityAxis.NONE:
41
+ pad = (0, 1, 0, 1)
42
+ case CausalityAxis.WIDTH:
43
+ pad = (2, 0, 0, 1)
44
+ case CausalityAxis.HEIGHT:
45
+ pad = (0, 1, 2, 0)
46
+ case CausalityAxis.WIDTH_COMPATIBILITY:
47
+ pad = (1, 0, 0, 1)
48
+ case _:
49
+ raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
50
+
51
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
52
+ x = self.conv(x)
53
+ else:
54
+ # This branch is only taken if with_conv=False, which implies causality_axis is NONE.
55
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
56
+
57
+ return x
58
+
59
+
60
+ def build_downsampling_path( # noqa: PLR0913
61
+ *,
62
+ ch: int,
63
+ ch_mult: Tuple[int, ...],
64
+ num_resolutions: int,
65
+ num_res_blocks: int,
66
+ resolution: int,
67
+ temb_channels: int,
68
+ dropout: float,
69
+ norm_type: NormType,
70
+ causality_axis: CausalityAxis,
71
+ attn_type: AttentionType,
72
+ attn_resolutions: Set[int],
73
+ resamp_with_conv: bool,
74
+ ) -> tuple[torch.nn.ModuleList, int]:
75
+ """Build the downsampling path with residual blocks, attention, and downsampling layers."""
76
+ down_modules = torch.nn.ModuleList()
77
+ curr_res = resolution
78
+ in_ch_mult = (1, *tuple(ch_mult))
79
+ block_in = ch
80
+
81
+ for i_level in range(num_resolutions):
82
+ block = torch.nn.ModuleList()
83
+ attn = torch.nn.ModuleList()
84
+ block_in = ch * in_ch_mult[i_level]
85
+ block_out = ch * ch_mult[i_level]
86
+
87
+ for _ in range(num_res_blocks):
88
+ block.append(
89
+ ResnetBlock(
90
+ in_channels=block_in,
91
+ out_channels=block_out,
92
+ temb_channels=temb_channels,
93
+ dropout=dropout,
94
+ norm_type=norm_type,
95
+ causality_axis=causality_axis,
96
+ )
97
+ )
98
+ block_in = block_out
99
+ if curr_res in attn_resolutions:
100
+ attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))
101
+
102
+ down = torch.nn.Module()
103
+ down.block = block
104
+ down.attn = attn
105
+ if i_level != num_resolutions - 1:
106
+ down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
107
+ curr_res = curr_res // 2
108
+ down_modules.append(down)
109
+
110
+ return down_modules, block_in
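The causal downsample's asymmetric padding keeps the stride-2 output at ceil(W/2) along width; a quick standalone shape check under the WIDTH axis:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 9)  # odd width, to show the ceil behaviour
conv = torch.nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=0)
y = conv(F.pad(x, (2, 0, 0, 1)))  # WIDTH-causal pad: (left, right, top, bottom)
# height: (8 + 1 - 3) // 2 + 1 = 4
# width:  (9 + 2 - 3) // 2 + 1 = 5  == ceil(9 / 2)
```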
ltx2/ltx_core/model/audio_vae/model_configurator.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+
3
+ from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps
4
+ from ltx_core.model.audio_vae.attention import AttentionType
5
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.audio_vae.vocoder import MelSTFT, Vocoder, VocoderWithBWE
8
+ from ltx_core.model.common.normalization import NormType
9
+ from ltx_core.model.model_protocol import ModelConfigurator
10
+ from ltx_core.utils import check_config_value
11
+
12
+
13
+ def _vocoder_from_config(
14
+ cfg: dict,
15
+ apply_final_activation: bool = True,
16
+ output_sampling_rate: int | None = None,
17
+ ) -> Vocoder:
18
+ """Instantiate a Vocoder from a flat config dict.
19
+ Args:
20
+ cfg: Vocoder config dict (keys match Vocoder constructor args).
21
+ apply_final_activation: Whether to apply tanh/clamp at the output.
22
+ output_sampling_rate: Explicit override for the output sample rate.
23
+ When None, reads from cfg["output_sampling_rate"] (default 24000).
24
+ """
25
+ return Vocoder(
26
+ resblock_kernel_sizes=cfg.get("resblock_kernel_sizes", [3, 7, 11]),
27
+ upsample_rates=cfg.get("upsample_rates", [6, 5, 2, 2, 2]),
28
+ upsample_kernel_sizes=cfg.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]),
29
+ resblock_dilation_sizes=cfg.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]),
30
+ upsample_initial_channel=cfg.get("upsample_initial_channel", 1024),
31
+ resblock=cfg.get("resblock", "1"),
32
+ output_sampling_rate=(
33
+ output_sampling_rate if output_sampling_rate is not None else cfg.get("output_sampling_rate", 24000)
34
+ ),
35
+ activation=cfg.get("activation", "snake"),
36
+ use_tanh_at_final=cfg.get("use_tanh_at_final", True),
37
+ apply_final_activation=apply_final_activation,
38
+ use_bias_at_final=cfg.get("use_bias_at_final", True),
39
+ )
40
+
41
+
42
+ class VocoderConfigurator(ModelConfigurator[Vocoder]):
43
+ """Configurator that auto-detects the checkpoint format.
44
+ Returns a plain Vocoder for pre-ltx-2.3 checkpoints (flat config) or a
45
+ VocoderWithBWE for ltx-2.3+ checkpoints (nested "vocoder" + "bwe" config).
46
+ """
47
+
48
+ @classmethod
49
+ def from_config(cls: type[Vocoder], config: dict) -> Vocoder | VocoderWithBWE:
50
+ cfg = config.get("vocoder", {})
51
+
52
+ if "bwe" not in cfg:
53
+ check_config_value(cfg, "resblock", "1")
54
+ check_config_value(cfg, "stereo", True)
55
+ return _vocoder_from_config(cfg)
56
+
57
+ vocoder_cfg = cfg.get("vocoder", {})
58
+ bwe_cfg = cfg["bwe"]
59
+
60
+ check_config_value(vocoder_cfg, "resblock", "AMP1")
61
+ check_config_value(vocoder_cfg, "stereo", True)
62
+ check_config_value(vocoder_cfg, "activation", "snakebeta")
63
+ check_config_value(bwe_cfg, "resblock", "AMP1")
64
+ check_config_value(bwe_cfg, "stereo", True)
65
+ check_config_value(bwe_cfg, "activation", "snakebeta")
66
+
67
+ vocoder = _vocoder_from_config(
68
+ vocoder_cfg,
69
+ output_sampling_rate=bwe_cfg["input_sampling_rate"],
70
+ )
71
+ bwe_generator = _vocoder_from_config(
72
+ bwe_cfg,
73
+ apply_final_activation=False,
74
+ output_sampling_rate=bwe_cfg["output_sampling_rate"],
75
+ )
76
+ mel_stft = MelSTFT(
77
+ filter_length=bwe_cfg["n_fft"],
78
+ hop_length=bwe_cfg["hop_length"],
79
+ win_length=bwe_cfg["n_fft"],
80
+ n_mel_channels=bwe_cfg["num_mels"],
81
+ )
82
+ return VocoderWithBWE(
83
+ vocoder=vocoder,
84
+ bwe_generator=bwe_generator,
85
+ mel_stft=mel_stft,
86
+ input_sampling_rate=bwe_cfg["input_sampling_rate"],
87
+ output_sampling_rate=bwe_cfg["output_sampling_rate"],
88
+ hop_length=bwe_cfg["hop_length"],
89
+ )
90
+
91
+
92
+ def _strip_vocoder_prefix(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
93
+ """Strip the leading 'vocoder.' prefix exactly once.
94
+ Uses removeprefix instead of str.replace so that BWE keys like
95
+ 'vocoder.vocoder.conv_pre' become 'vocoder.conv_pre' (not 'conv_pre').
96
+ Works identically for legacy keys like 'vocoder.conv_pre' → 'conv_pre'.
97
+ """
98
+ return [KeyValueOperationResult(key.removeprefix("vocoder."), value)]
99
+
100
+
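The distinction the docstring calls out is easy to demonstrate with plain strings:

```python
# str.removeprefix strips at most one leading occurrence, while
# str.replace rewrites every occurrence -- which would mangle nested
# BWE keys such as "vocoder.vocoder.conv_pre".
key = "vocoder.vocoder.conv_pre"
stripped = key.removeprefix("vocoder.")
clobbered = key.replace("vocoder.", "")
```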
101
+ VOCODER_COMFY_KEYS_FILTER = (
102
+ SDOps("VOCODER_COMFY_KEYS_FILTER")
103
+ .with_matching(prefix="vocoder.")
104
+ .with_kv_operation(operation=_strip_vocoder_prefix, key_prefix="vocoder.")
105
+ )
106
+
107
+
108
+ class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]):
109
+ @classmethod
110
+ def from_config(cls: type[AudioDecoder], config: dict) -> AudioDecoder:
111
+ audio_vae_cfg = config.get("audio_vae", {})
112
+ model_cfg = audio_vae_cfg.get("model", {})
113
+ model_params = model_cfg.get("params", {})
114
+ ddconfig = model_params.get("ddconfig", {})
115
+ preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
116
+ stft_cfg = preprocessing_cfg.get("stft", {})
117
+ mel_cfg = preprocessing_cfg.get("mel", {})
118
+ variables_cfg = audio_vae_cfg.get("variables", {})
119
+
120
+ sample_rate = model_params.get("sampling_rate", 16000)
121
+ mel_hop_length = stft_cfg.get("hop_length", 160)
122
+ is_causal = stft_cfg.get("causal", True)
123
+ mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")
124
+
125
+ return AudioDecoder(
126
+ ch=ddconfig.get("ch", 128),
127
+ out_ch=ddconfig.get("out_ch", 2),
128
+ ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
129
+ num_res_blocks=ddconfig.get("num_res_blocks", 2),
130
+ attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
131
+ resolution=ddconfig.get("resolution", 256),
132
+ z_channels=ddconfig.get("z_channels", 8),
133
+ norm_type=NormType(ddconfig.get("norm_type", "pixel")),
134
+ causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
135
+ dropout=ddconfig.get("dropout", 0.0),
136
+ mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
137
+ sample_rate=sample_rate,
138
+ mel_hop_length=mel_hop_length,
139
+ is_causal=is_causal,
140
+ mel_bins=mel_bins,
141
+ )
142
+
143
+
144
+ class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]):
145
+ @classmethod
146
+ def from_config(cls: type[AudioEncoder], config: dict) -> AudioEncoder:
147
+ audio_vae_cfg = config.get("audio_vae", {})
148
+ model_cfg = audio_vae_cfg.get("model", {})
149
+ model_params = model_cfg.get("params", {})
150
+ ddconfig = model_params.get("ddconfig", {})
151
+ preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
152
+ stft_cfg = preprocessing_cfg.get("stft", {})
153
+ mel_cfg = preprocessing_cfg.get("mel", {})
154
+ variables_cfg = audio_vae_cfg.get("variables", {})
155
+
156
+ sample_rate = model_params.get("sampling_rate", 16000)
157
+ mel_hop_length = stft_cfg.get("hop_length", 160)
158
+ n_fft = stft_cfg.get("filter_length", 1024)
159
+ is_causal = stft_cfg.get("causal", True)
160
+ mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")
161
+
162
+ return AudioEncoder(
163
+ ch=ddconfig.get("ch", 128),
164
+ ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
165
+ num_res_blocks=ddconfig.get("num_res_blocks", 2),
166
+ attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
167
+ resolution=ddconfig.get("resolution", 256),
168
+ z_channels=ddconfig.get("z_channels", 8),
169
+ double_z=ddconfig.get("double_z", True),
170
+ dropout=ddconfig.get("dropout", 0.0),
171
+ resamp_with_conv=ddconfig.get("resamp_with_conv", True),
172
+ in_channels=ddconfig.get("in_channels", 2),
173
+ attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")),
174
+ mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
175
+ norm_type=NormType(ddconfig.get("norm_type", "pixel")),
176
+ causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
177
+ sample_rate=sample_rate,
178
+ mel_hop_length=mel_hop_length,
179
+ n_fft=n_fft,
180
+ is_causal=is_causal,
181
+ mel_bins=mel_bins,
182
+ )
183
+
184
+
185
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = (
186
+ SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER")
187
+ .with_matching(prefix="audio_vae.decoder.")
188
+ .with_matching(prefix="audio_vae.per_channel_statistics.")
189
+ .with_replacement("audio_vae.decoder.", "")
190
+ .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
191
+ )
192
+
193
+
194
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = (
195
+ SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER")
196
+ .with_matching(prefix="audio_vae.encoder.")
197
+ .with_matching(prefix="audio_vae.per_channel_statistics.")
198
+ .with_replacement("audio_vae.encoder.", "")
199
+ .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
200
+ )
ltx2/ltx_core/model/audio_vae/ops.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torchaudio
+ from torch import nn
+
+ from ltx_core.types import Audio
+
+
+ class AudioProcessor(nn.Module):
+     """Converts audio waveforms to log-mel spectrograms with optional resampling."""
+
+     def __init__(
+         self,
+         target_sample_rate: int,
+         mel_bins: int,
+         mel_hop_length: int,
+         n_fft: int,
+     ) -> None:
+         super().__init__()
+         self.target_sample_rate = target_sample_rate
+         self.mel_transform = torchaudio.transforms.MelSpectrogram(
+             sample_rate=target_sample_rate,
+             n_fft=n_fft,
+             win_length=n_fft,
+             hop_length=mel_hop_length,
+             f_min=0.0,
+             f_max=target_sample_rate / 2.0,
+             n_mels=mel_bins,
+             window_fn=torch.hann_window,
+             center=True,
+             pad_mode="reflect",
+             power=1.0,
+             mel_scale="slaney",
+             norm="slaney",
+         )
+
+     def resample_audio(self, audio: Audio) -> Audio:
+         """Resample audio to the processor's target sample rate if needed."""
+         if audio.sampling_rate == self.target_sample_rate:
+             return audio
+         resampled = torchaudio.functional.resample(audio.waveform, audio.sampling_rate, self.target_sample_rate)
+         resampled = resampled.to(device=audio.waveform.device, dtype=audio.waveform.dtype)
+         return Audio(waveform=resampled, sampling_rate=self.target_sample_rate)
+
+     def waveform_to_mel(
+         self,
+         audio: Audio,
+     ) -> torch.Tensor:
+         """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
+         waveform = self.resample_audio(audio).waveform
+
+         mel = self.mel_transform(waveform)
+         mel = torch.log(torch.clamp(mel, min=1e-5))
+
+         mel = mel.to(device=waveform.device, dtype=waveform.dtype)
+         return mel.permute(0, 1, 3, 2).contiguous()
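The clamp-before-log in `waveform_to_mel` puts a hard floor under silent mel bins so the log never produces `-inf`. A scalar sketch of that compression (pure `math`, no torch; `log_compress` is an illustrative name):

```python
import math


def log_compress(mel_value, floor=1e-5):
    # Clamp-before-log keeps silent bins finite at log(1e-5), about -11.51.
    return math.log(max(mel_value, floor))
```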
+
+
+ class PerChannelStatistics(nn.Module):
+     """
+     Per-channel statistics for normalizing and denormalizing the latent representation.
+     These statistics are computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
+     """
+
+     def __init__(self, latent_channels: int = 128) -> None:
+         super().__init__()
+         self.register_buffer("std-of-means", torch.empty(latent_channels))
+         self.register_buffer("mean-of-means", torch.empty(latent_channels))
+
+     def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
+         return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
+
+     def normalize(self, x: torch.Tensor) -> torch.Tensor:
+         return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
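`normalize` and `un_normalize` are exact inverses by construction: subtract the mean and divide by the std, then multiply back and add. A list-based round-trip sketch with illustrative statistics (no torch):

```python
mean_of_means = [0.5, -1.0]
std_of_means = [2.0, 0.25]


def normalize(x):
    # (x - mean) / std, per channel
    return [(xi - m) / s for xi, m, s in zip(x, mean_of_means, std_of_means)]


def un_normalize(z):
    # z * std + mean, per channel
    return [zi * s + m for zi, m, s in zip(z, mean_of_means, std_of_means)]


latent = [3.0, -1.5]
restored = un_normalize(normalize(latent))
```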
ltx2/ltx_core/model/audio_vae/resnet.py ADDED
@@ -0,0 +1,176 @@
+ from typing import Tuple
+
+ import torch
+
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
+
+ LRELU_SLOPE = 0.1
+
+
+ class ResBlock1(torch.nn.Module):
+     def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
+         super(ResBlock1, self).__init__()
+         self.convs1 = torch.nn.ModuleList(
+             [
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[0],
+                     padding="same",
+                 ),
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[1],
+                     padding="same",
+                 ),
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[2],
+                     padding="same",
+                 ),
+             ]
+         )
+
+         self.convs2 = torch.nn.ModuleList(
+             [
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=1,
+                     padding="same",
+                 ),
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=1,
+                     padding="same",
+                 ),
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=1,
+                     padding="same",
+                 ),
+             ]
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
+             xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
+             xt = conv1(xt)
+             xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
+             xt = conv2(xt)
+             x = xt + x
+         return x
+
+
+ class ResBlock2(torch.nn.Module):
+     def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
+         super(ResBlock2, self).__init__()
+         self.convs = torch.nn.ModuleList(
+             [
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[0],
+                     padding="same",
+                 ),
+                 torch.nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[1],
+                     padding="same",
+                 ),
+             ]
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for conv in self.convs:
+             xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
+             xt = conv(xt)
+             x = xt + x
+         return x
+
+
+ class ResnetBlock(torch.nn.Module):
+     def __init__(
+         self,
+         *,
+         in_channels: int,
+         out_channels: int | None = None,
+         conv_shortcut: bool = False,
+         dropout: float = 0.0,
+         temb_channels: int = 512,
+         norm_type: NormType = NormType.GROUP,
+         causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
+     ) -> None:
+         super().__init__()
+         self.causality_axis = causality_axis
+
+         if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
+             raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+
+         self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
+         self.non_linearity = torch.nn.SiLU()
+         self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
+         if temb_channels > 0:
+             self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+         self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
+         self.dropout = torch.nn.Dropout(dropout)
+         self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = make_conv2d(
+                     in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
+                 )
+             else:
+                 self.nin_shortcut = make_conv2d(
+                     in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
+                 )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         temb: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         h = x
+         h = self.norm1(h)
+         h = self.non_linearity(h)
+         h = self.conv1(h)
+
+         if temb is not None:
+             h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
+
+         h = self.norm2(h)
+         h = self.non_linearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
+
+         return x + h
ltx2/ltx_core/model/audio_vae/upsample.py ADDED
@@ -0,0 +1,106 @@
+ from typing import Set, Tuple
+
+ import torch
+
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
+ from ltx_core.model.common.normalization import NormType
+
+
+ class Upsample(torch.nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         with_conv: bool,
+         causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
+     ) -> None:
+         super().__init__()
+         self.with_conv = with_conv
+         self.causality_axis = causality_axis
+         if self.with_conv:
+             self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+         if self.with_conv:
+             x = self.conv(x)
+         # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
+         # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
+         # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
+         # So the output elements rely on the following windows:
+         # 0: [-,-,0]
+         # 1: [-,0,0]
+         # 2: [0,0,1]
+         # 3: [0,1,1]
+         # 4: [1,1,2]
+         # 5: [1,2,2]
+         # Notice that the first and second elements in the output rely only on the first element in the input,
+         # while all other elements rely on two elements in the input.
+         # So we can drop the first element to undo the padding (rather than the last element).
+         # This is a no-op for non-causal convolutions.
+         match self.causality_axis:
+             case CausalityAxis.NONE:
+                 pass  # x remains unchanged
+             case CausalityAxis.HEIGHT:
+                 x = x[:, :, 1:, :]
+             case CausalityAxis.WIDTH:
+                 x = x[:, :, :, 1:]
+             case CausalityAxis.WIDTH_COMPATIBILITY:
+                 pass  # x remains unchanged
+             case _:
+                 raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
+
+         return x
+
+
+ def build_upsampling_path(  # noqa: PLR0913
+     *,
+     ch: int,
+     ch_mult: Tuple[int, ...],
+     num_resolutions: int,
+     num_res_blocks: int,
+     resolution: int,
+     temb_channels: int,
+     dropout: float,
+     norm_type: NormType,
+     causality_axis: CausalityAxis,
+     attn_type: AttentionType,
+     attn_resolutions: Set[int],
+     resamp_with_conv: bool,
+     initial_block_channels: int,
+ ) -> tuple[torch.nn.ModuleList, int]:
+     """Build the upsampling path with residual blocks, attention, and upsampling layers."""
+     up_modules = torch.nn.ModuleList()
+     block_in = initial_block_channels
+     curr_res = resolution // (2 ** (num_resolutions - 1))
+
+     for level in reversed(range(num_resolutions)):
+         stage = torch.nn.Module()
+         stage.block = torch.nn.ModuleList()
+         stage.attn = torch.nn.ModuleList()
+         block_out = ch * ch_mult[level]
+
+         for _ in range(num_res_blocks + 1):
+             stage.block.append(
+                 ResnetBlock(
+                     in_channels=block_in,
+                     out_channels=block_out,
+                     temb_channels=temb_channels,
+                     dropout=dropout,
+                     norm_type=norm_type,
+                     causality_axis=causality_axis,
+                 )
+             )
+             block_in = block_out
+             if curr_res in attn_resolutions:
+                 stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))
+
+         if level != 0:
+             stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
+             curr_res *= 2
+
+         up_modules.insert(0, stage)
+
+     return up_modules, block_in
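The long comment in `Upsample.forward` walks through why the first element (not the last) is dropped along the causal axis after nearest-neighbor 2x upsampling. That bookkeeping can be checked with a plain-list sketch of the same example:

```python
def nearest_upsample_2x(xs):
    # Nearest-neighbor interpolation with scale_factor=2 repeats each element.
    return [x for x in xs for _ in range(2)]


x = [0, 1, 2]
up = nearest_upsample_2x(x)
causal = up[1:]  # drop the FIRST element, keeping length 1 + 2 * n
```

Only the first output element of the causal window depends solely on the first input, so dropping it undoes the encoder's padding without discarding information about later inputs.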
ltx2/ltx_core/model/audio_vae/vocoder.py ADDED
@@ -0,0 +1,594 @@
+ import math
+ from typing import List
+
+ import einops
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1
+
+
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+     return int((kernel_size * dilation - dilation) / 2)
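`get_padding` is the usual "same" padding for a stride-1 dilated conv with an odd kernel: with pad = dilation * (kernel_size - 1) / 2 the output length equals the input length. A quick check against the standard Conv1d length formula (a sketch; `conv1d_out_len` is an illustrative helper, not part of the repo):

```python
def get_padding(kernel_size, dilation=1):
    return (kernel_size * dilation - dilation) // 2


def conv1d_out_len(length, kernel_size, dilation, pad, stride=1):
    # Standard Conv1d output-length formula.
    return (length + 2 * pad - dilation * (kernel_size - 1) - 1) // stride + 1


checks = [
    conv1d_out_len(100, k, d, get_padding(k, d))
    for k, d in [(3, 1), (3, 5), (11, 1), (7, 2)]
]
```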
+
+
+ # ---------------------------------------------------------------------------
+ # Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
+ # Adopted from https://github.com/NVIDIA/BigVGAN
+ # ---------------------------------------------------------------------------
+
+
+ def _sinc(x: torch.Tensor) -> torch.Tensor:
+     return torch.where(
+         x == 0,
+         torch.tensor(1.0, device=x.device, dtype=x.dtype),
+         torch.sin(math.pi * x) / math.pi / x,
+     )
+
+
+ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
+     even = kernel_size % 2 == 0
+     half_size = kernel_size // 2
+     delta_f = 4 * half_width
+     amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+     if amplitude > 50.0:
+         beta = 0.1102 * (amplitude - 8.7)
+     elif amplitude >= 21.0:
+         beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
+     else:
+         beta = 0.0
+     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+     time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
+     if cutoff == 0:
+         filter_ = torch.zeros_like(time)
+     else:
+         filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
+         filter_ /= filter_.sum()
+     return filter_.view(1, 1, kernel_size)
+
+
+ class LowPassFilter1d(nn.Module):
+     def __init__(
+         self,
+         cutoff: float = 0.5,
+         half_width: float = 0.6,
+         stride: int = 1,
+         padding: bool = True,
+         padding_mode: str = "replicate",
+         kernel_size: int = 12,
+     ) -> None:
+         super().__init__()
+         if cutoff < -0.0:
+             raise ValueError("Minimum cutoff must be larger than zero.")
+         if cutoff > 0.5:
+             raise ValueError("A cutoff above 0.5 does not make sense.")
+         self.kernel_size = kernel_size
+         self.even = kernel_size % 2 == 0
+         self.pad_left = kernel_size // 2 - int(self.even)
+         self.pad_right = kernel_size // 2
+         self.stride = stride
+         self.padding = padding
+         self.padding_mode = padding_mode
+         self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         _, n_channels, _ = x.shape
+         if self.padding:
+             x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+         return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)
+
+
+ class UpSample1d(nn.Module):
+     def __init__(
+         self,
+         ratio: int = 2,
+         kernel_size: int | None = None,
+         persistent: bool = True,
+         window_type: str = "kaiser",
+     ) -> None:
+         super().__init__()
+         self.ratio = ratio
+         self.stride = ratio
+
+         if window_type == "hann":
+             # Hann-windowed sinc filter equivalent to torchaudio.functional.resample
+             rolloff = 0.99
+             lowpass_filter_width = 6
+             width = math.ceil(lowpass_filter_width / rolloff)
+             self.kernel_size = 2 * width * ratio + 1
+             self.pad = width
+             self.pad_left = 2 * width * ratio
+             self.pad_right = self.kernel_size - ratio
+             time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
+             time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
+             window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
+             sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
+         else:
+             # Kaiser-windowed sinc filter (BigVGAN default).
+             self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+             self.pad = self.kernel_size // ratio - 1
+             self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+             self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+             sinc_filter = kaiser_sinc_filter1d(
+                 cutoff=0.5 / ratio,
+                 half_width=0.6 / ratio,
+                 kernel_size=self.kernel_size,
+             )
+
+         self.register_buffer("filter", sinc_filter, persistent=persistent)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         _, n_channels, _ = x.shape
+         x = F.pad(x, (self.pad, self.pad), mode="replicate")
+         filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
+         x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
+         return x[..., self.pad_left : -self.pad_right]
+
+
+ class DownSample1d(nn.Module):
+     def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         self.lowpass = LowPassFilter1d(
+             cutoff=0.5 / ratio,
+             half_width=0.6 / ratio,
+             stride=ratio,
+             kernel_size=self.kernel_size,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.lowpass(x)
+
+
+ class Activation1d(nn.Module):
+     def __init__(
+         self,
+         activation: nn.Module,
+         up_ratio: int = 2,
+         down_ratio: int = 2,
+         up_kernel_size: int = 12,
+         down_kernel_size: int = 12,
+     ) -> None:
+         super().__init__()
+         self.act = activation
+         self.upsample = UpSample1d(up_ratio, up_kernel_size)
+         self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.upsample(x)
+         x = self.act(x)
+         return self.downsample(x)
+
+
+ class Snake(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         alpha: float = 1.0,
+         alpha_trainable: bool = True,
+         alpha_logscale: bool = True,
+     ) -> None:
+         super().__init__()
+         self.alpha_logscale = alpha_logscale
+         self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
+         self.alpha.requires_grad = alpha_trainable
+         self.eps = 1e-9
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+         if self.alpha_logscale:
+             alpha = torch.exp(alpha)
+         return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)
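Snake acts elementwise as x + (1/alpha) * sin(alpha * x)^2: periodic bumps added on top of the identity, zero-preserving and never below x for positive alpha. A scalar sketch (ignoring the per-channel alpha and eps handling above):

```python
import math


def snake(x, alpha=1.0):
    # x + (1/alpha) * sin(alpha * x)^2 -- zero-preserving, >= x for alpha > 0
    return x + (1.0 / alpha) * math.sin(alpha * x) ** 2
```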
+
+
+ class SnakeBeta(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         alpha: float = 1.0,
+         alpha_trainable: bool = True,
+         alpha_logscale: bool = True,
+     ) -> None:
+         super().__init__()
+         self.alpha_logscale = alpha_logscale
+         self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
+         self.alpha.requires_grad = alpha_trainable
+         self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
+         self.beta.requires_grad = alpha_trainable
+         self.eps = 1e-9
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+         beta = self.beta.unsqueeze(0).unsqueeze(-1)
+         if self.alpha_logscale:
+             alpha = torch.exp(alpha)
+             beta = torch.exp(beta)
+         return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)
+
+
+ class AMPBlock1(nn.Module):
+     def __init__(
+         self,
+         channels: int,
+         kernel_size: int = 3,
+         dilation: tuple[int, int, int] = (1, 3, 5),
+         activation: str = "snake",
+     ) -> None:
+         super().__init__()
+         act_cls = SnakeBeta if activation == "snakebeta" else Snake
+         self.convs1 = nn.ModuleList(
+             [
+                 nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[0],
+                     padding=get_padding(kernel_size, dilation[0]),
+                 ),
+                 nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[1],
+                     padding=get_padding(kernel_size, dilation[1]),
+                 ),
+                 nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     1,
+                     dilation=dilation[2],
+                     padding=get_padding(kernel_size, dilation[2]),
+                 ),
+             ]
+         )
+
+         self.convs2 = nn.ModuleList(
+             [
+                 nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+                 nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+                 nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+             ]
+         )
+
+         self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
+         self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
+             xt = a1(x)
+             xt = c1(xt)
+             xt = a2(xt)
+             xt = c2(xt)
+             x = x + xt
+         return x
+
+
+ class Vocoder(torch.nn.Module):
+     """
+     Vocoder model for synthesizing audio from Mel spectrograms.
+     Args:
+         resblock_kernel_sizes: List of kernel sizes for the residual blocks.
+             This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
+         upsample_rates: List of upsampling rates.
+             This value is read from the checkpoint at `config.vocoder.upsample_rates`.
+         upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
+             This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
+         resblock_dilation_sizes: List of dilation sizes for the residual blocks.
+             This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
+         upsample_initial_channel: Initial number of channels for the upsampling layers.
+             This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
+         resblock: Type of residual block to use ("1", "2", or "AMP1").
+             This value is read from the checkpoint at `config.vocoder.resblock`.
+         output_sampling_rate: Waveform sample rate.
+             This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
+         activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
+         use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
+         apply_final_activation: Whether to apply the final tanh/clamp activation.
+         use_bias_at_final: Whether to use bias in the final conv layer.
+     """
+
+     def __init__(  # noqa: PLR0913
+         self,
+         resblock_kernel_sizes: List[int] | None = None,
+         upsample_rates: List[int] | None = None,
+         upsample_kernel_sizes: List[int] | None = None,
+         resblock_dilation_sizes: List[List[int]] | None = None,
+         upsample_initial_channel: int = 1024,
+         resblock: str = "1",
+         output_sampling_rate: int = 24000,
+         activation: str = "snake",
+         use_tanh_at_final: bool = True,
+         apply_final_activation: bool = True,
+         use_bias_at_final: bool = True,
+     ) -> None:
+         super().__init__()
+
+         # Mutable values are unsafe as default arguments, so the defaults are filled in here.
+         if resblock_kernel_sizes is None:
+             resblock_kernel_sizes = [3, 7, 11]
+         if upsample_rates is None:
+             upsample_rates = [6, 5, 2, 2, 2]
+         if upsample_kernel_sizes is None:
+             upsample_kernel_sizes = [16, 15, 8, 4, 4]
+         if resblock_dilation_sizes is None:
+             resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
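With the defaults above, each mel frame is expanded by the product of the stage rates, while the channel width halves at every ConvTranspose1d stage starting from `upsample_initial_channel`. A quick sanity check of that arithmetic:

```python
from math import prod

upsample_rates = [6, 5, 2, 2, 2]
upsample_initial_channel = 1024

# One input frame expands by the product of the per-stage upsampling rates.
total_upsampling = prod(upsample_rates)

# Channel count after each stage: upsample_initial_channel // 2**(i + 1).
stage_channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(len(upsample_rates))]
```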
320
+
321
+ self.output_sampling_rate = output_sampling_rate
322
+ self.num_kernels = len(resblock_kernel_sizes)
323
+ self.num_upsamples = len(upsample_rates)
324
+ self.use_tanh_at_final = use_tanh_at_final
325
+ self.apply_final_activation = apply_final_activation
326
+ self.is_amp = resblock == "AMP1"
327
+
328
+ # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
329
+ # bins each), 2 output channels.
330
+ self.conv_pre = nn.Conv1d(
331
+ in_channels=128,
332
+ out_channels=upsample_initial_channel,
333
+ kernel_size=7,
334
+ stride=1,
335
+ padding=3,
336
+ )
337
+ resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
338
+
339
+ self.ups = nn.ModuleList(
340
+ nn.ConvTranspose1d(
341
+ upsample_initial_channel // (2**i),
342
+ upsample_initial_channel // (2 ** (i + 1)),
343
+ kernel_size,
344
+ stride,
345
+ padding=(kernel_size - stride) // 2,
346
+ )
347
+ for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
348
+ )
349
+
350
+ final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
351
+ self.resblocks = nn.ModuleList()
352
+
353
+ for i in range(len(upsample_rates)):
354
+ ch = upsample_initial_channel // (2 ** (i + 1))
355
+ for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
356
+ if self.is_amp:
357
+ self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
358
+ else:
359
+ self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
360
+
361
+ if self.is_amp:
362
+ self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
363
+ else:
364
+ self.act_post = nn.LeakyReLU()
365
+
366
+ # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
367
+ self.conv_post = nn.Conv1d(
368
+ in_channels=final_channels,
369
+ out_channels=2,
370
+ kernel_size=7,
371
+ stride=1,
372
+ padding=3,
373
+ bias=use_bias_at_final,
374
+ )
375
+
376
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
377
+ """
378
+ Forward pass of the vocoder.
379
+ Args:
380
+ x: Input Mel spectrogram tensor. Can be either:
381
+ - 3D: (batch_size, time, mel_bins) for mono
382
+ - 4D: (batch_size, 2, time, mel_bins) for stereo
383
+ Returns:
384
+ Audio waveform tensor of shape (batch_size, out_channels, audio_length)
385
+ """
386
+ x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time)
387
+
388
+ if x.dim() == 4: # stereo
389
+ assert x.shape[1] == 2, "Input must have 2 channels for stereo"
390
+ x = einops.rearrange(x, "b s c t -> b (s c) t")
391
+
392
+ x = self.conv_pre(x)
393
+
394
+ for i in range(self.num_upsamples):
395
+ if not self.is_amp:
396
+ x = F.leaky_relu(x, LRELU_SLOPE)
397
+ x = self.ups[i](x)
398
+ start = i * self.num_kernels
399
+ end = start + self.num_kernels
400
+
401
+ # Evaluate all resblocks with the same input tensor so they can run
402
+ # independently (and thus in parallel on accelerator hardware) before
403
+ # aggregating their outputs via mean.
404
+ block_outputs = torch.stack(
405
+ [self.resblocks[idx](x) for idx in range(start, end)],
406
+ dim=0,
407
+ )
408
+ x = block_outputs.mean(dim=0)
409
+
410
+ x = self.act_post(x)
411
+ x = self.conv_post(x)
412
+
413
+ if self.apply_final_activation:
414
+ x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
415
+
416
+ return x
417
+
418
+
419
+ class _STFTFn(nn.Module):
420
+ """Implements STFT as a convolution with precomputed DFT x Hann-window bases.
421
+ The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
422
+ Hann window are stored as buffers and loaded from the checkpoint. Using the exact
423
+ bfloat16 bases from training ensures the mel values fed to the BWE generator are
424
+ bit-identical to what it was trained on.
425
+ """
426
+
427
+ def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
428
+ super().__init__()
429
+ self.hop_length = hop_length
430
+ self.win_length = win_length
431
+ n_freqs = filter_length // 2 + 1
432
+ self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
433
+ self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
434
+
435
+ def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
436
+ """Compute magnitude and phase spectrogram from a batch of waveforms.
437
+ Applies causal (left-only) padding of win_length - hop_length samples so that
438
+ each output frame depends only on past and present input — no lookahead.
439
+ Args:
440
+ y: Waveform tensor of shape (B, T).
441
+ Returns:
442
+ magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
443
+ phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
444
+ """
445
+ if y.dim() == 2:
446
+ y = y.unsqueeze(1) # (B, 1, T)
447
+ left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
448
+ y = F.pad(y, (left_pad, 0))
449
+ spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
450
+ n_freqs = spec.shape[1] // 2
451
+ real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
452
+ magnitude = torch.sqrt(real**2 + imag**2)
453
+ phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
454
+ return magnitude, phase
455
+
456
+
457
+ class MelSTFT(nn.Module):
458
+ """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
459
+ Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
460
+ waveform and projecting the linear magnitude spectrum onto the mel filterbank.
461
+ The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
462
+ (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
463
+ """
464
+
465
+ def __init__(
466
+ self,
467
+ filter_length: int,
468
+ hop_length: int,
469
+ win_length: int,
470
+ n_mel_channels: int,
471
+ ) -> None:
472
+ super().__init__()
473
+ self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
474
+
475
+ # Initialized to zeros; load_state_dict overwrites with the checkpoint's
476
+ # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
477
+ n_freqs = filter_length // 2 + 1
478
+ self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
479
+
480
+ def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
481
+ """Compute log-mel spectrogram and auxiliary spectral quantities.
482
+ Args:
483
+ y: Waveform tensor of shape (B, T).
484
+ Returns:
485
+ log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
486
+ magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
487
+ phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
488
+ energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
489
+ """
490
+ magnitude, phase = self.stft_fn(y)
491
+ energy = torch.norm(magnitude, dim=1)
492
+ mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
493
+ log_mel = torch.log(torch.clamp(mel, min=1e-5))
494
+ return log_mel, magnitude, phase, energy
495
+
496
+
497
+ class VocoderWithBWE(nn.Module):
+     """Vocoder with bandwidth extension (BWE) upsampling.
+
+     Chains a mel-to-wav vocoder with a BWE module that upsamples the output
+     to a higher sample rate. The BWE computes a mel spectrogram from the
+     vocoder output, runs it through a second generator to predict a residual,
+     and adds it to a sinc-resampled skip connection.
+
+     The forward pass runs in fp32 via autocast to avoid bfloat16 accumulation
+     errors that degrade spectral metrics by 40-90%.
+     """
+
+     def __init__(
+         self,
+         vocoder: Vocoder,
+         bwe_generator: Vocoder,
+         mel_stft: MelSTFT,
+         input_sampling_rate: int,
+         output_sampling_rate: int,
+         hop_length: int,
+     ) -> None:
+         super().__init__()
+         self.vocoder = vocoder
+         self.bwe_generator = bwe_generator
+         self.mel_stft = mel_stft
+         self.input_sampling_rate = input_sampling_rate
+         self.output_sampling_rate = output_sampling_rate
+         self.hop_length = hop_length
+         # Construct the resampler on CPU so the sinc filter is materialized even when
+         # the model is constructed on meta device (SingleGPUModelBuilder pattern).
+         # The filter is not stored in the checkpoint (persistent=False).
+         with torch.device("cpu"):
+             self.resampler = UpSample1d(
+                 ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
+             )
+
+     @property
+     def conv_pre(self) -> nn.Conv1d:
+         return self.vocoder.conv_pre
+
+     @property
+     def conv_post(self) -> nn.Conv1d:
+         return self.vocoder.conv_post
+
+     def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
+         """Compute log-mel spectrogram from waveform using causal STFT bases.
+
+         Args:
+             audio: Waveform tensor of shape (B, C, T).
+
+         Returns:
+             mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
+         """
+         batch, n_channels, _ = audio.shape
+         flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)
+         mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
+         return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)
+
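`_compute_mel` uses the standard fold-channels-into-batch trick: stereo (or any channel count) is flattened into the batch dimension so a mono-only transform can process every channel in one call, then reshaped back. A numpy sketch of the reshape round-trip, with a toy framing function standing in for `mel_spectrogram`:

```python
import numpy as np

def per_channel_apply(audio, fn):
    """Fold channels into the batch dim, apply a per-item fn, unfold.
    Mirrors _compute_mel: (B, C, T) -> (B*C, T) -> fn -> (B, C, ...)."""
    batch, n_channels, _ = audio.shape
    flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)
    out = fn(flat)                                # e.g. (B*C, n_mels, T_frames)
    return out.reshape(batch, n_channels, *out.shape[1:])

# Stand-in "mel" fn: frames each waveform into 4 windows of 8 samples.
toy_fn = lambda x: x.reshape(x.shape[0], 4, 8)
audio = np.zeros((2, 2, 32))  # batch of 2, stereo, 32 samples
print(per_channel_apply(audio, toy_fn).shape)  # → (2, 2, 4, 8)
```

The same pattern works for any per-item op whose output shape is fixed, which is why the method can return the four-dimensional (B, C, n_mels, T_frames) layout the BWE generator expects.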
+     def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
+         """Run the full vocoder + BWE forward pass.
+
+         Runs in float32 regardless of weight or input dtype. bfloat16 arithmetic
+         causes 40-90% spectral metric degradation due to accumulation errors
+         compounding through 108 sequential convolutions in the BigVGAN v2 architecture.
+
+         Args:
+             mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
+                 or (B, T, mel_bins) for mono. Same format as Vocoder.forward.
+
+         Returns:
+             Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
+         """
+         input_dtype = mel_spec.dtype
+         # Run the entire forward pass in fp32. bfloat16 accumulation errors
+         # compound through 108 sequential convolutions and degrade spectral
+         # metrics (mel_l1, MRSTFT) by 40-90% while perceptual quality (CDPAM)
+         # is unaffected. fp32 eliminates this degradation.
+         # We use autocast(dtype=float32) rather than self.float() because it
+         # upcasts bf16 weights per-op at kernel level, avoiding the temporary
+         # memory spike of self.float() / self.to(original_dtype).
+         # Benchmarked on H100 (128.5M-param model):
+         #   autocast fp32: +70 MB peak VRAM, 123 ms (vs 482 MB / 95 ms for bf16)
+         #   model.float(): +324 MB peak VRAM, 149 ms
+         # Tested: both approaches produce bit-identical output.
+
+         with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
+             x = self.vocoder(mel_spec.float())
+             _, _, length_low_rate = x.shape
+             output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
+
+             # Pad to a multiple of hop_length for an exact mel frame count.
+             remainder = length_low_rate % self.hop_length
+             if remainder != 0:
+                 x = F.pad(x, (0, self.hop_length - remainder))
+
+             # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
+             mel = self._compute_mel(x)
+
+             # Vocoder.forward expects (B, C, T, mel_bins); transpose before calling bwe_generator.
+             mel_for_bwe = mel.transpose(2, 3)  # (B, C, T_frames, mel_bins)
+             residual = self.bwe_generator(mel_for_bwe)
+             skip = self.resampler(x)
+             assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
+
+             return torch.clamp(residual + skip, -1, 1)[..., :output_length].to(input_dtype)
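The fp32 pin is motivated by rounding error that compounds across many sequential low-precision ops. The toy numpy comparison below illustrates the effect under loose stand-ins: float16 for bfloat16 and 108 scalar scalings for the 108 convolutions (the VRAM/latency figures in the comment above are from the commit's own benchmark, not reproduced here):

```python
import numpy as np

x32 = np.full(1024, 0.1, dtype=np.float32)
x16 = x32.astype(np.float16)
for _ in range(108):  # stand-in for 108 sequential convolutions
    x16 = (x16 * np.float16(1.01)).astype(np.float16)  # rounds after every op
    x32 = x32 * np.float32(1.01)                       # keeps fp32 precision
rel_err = float(np.abs(x16.astype(np.float32) - x32).max() / np.abs(x32).max())
print(f"relative drift after 108 half-precision ops: {rel_err:.2e}")
```

Even this trivial chain drifts by roughly a percent, because each step re-rounds both the operand and the multiplier; per-op upcasting via autocast avoids the drift without materializing a full fp32 copy of the weights.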
ltx2/ltx_core/model/common/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """Common model utilities."""
+
+ from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer
+
+ __all__ = [
+     "NormType",
+     "PixelNorm",
+     "build_normalization_layer",
+ ]