DramaBox Space — initial app + vendored ltx2
- LICENSE +381 -0
- README.md +151 -31
- app.py +176 -0
- assets/silence_latent_frame.pt +3 -0
- configs/training_args.example.yaml +53 -0
- ltx2/ltx_core/__init__.py +0 -0
- ltx2/ltx_core/batch_split.py +95 -0
- ltx2/ltx_core/components/__init__.py +10 -0
- ltx2/ltx_core/components/diffusion_steps.py +106 -0
- ltx2/ltx_core/components/guiders.py +383 -0
- ltx2/ltx_core/components/noisers.py +35 -0
- ltx2/ltx_core/components/patchifiers.py +348 -0
- ltx2/ltx_core/components/protocols.py +101 -0
- ltx2/ltx_core/components/schedulers.py +130 -0
- ltx2/ltx_core/conditioning/__init__.py +19 -0
- ltx2/ltx_core/conditioning/exceptions.py +4 -0
- ltx2/ltx_core/conditioning/item.py +20 -0
- ltx2/ltx_core/conditioning/mask_utils.py +210 -0
- ltx2/ltx_core/conditioning/types/__init__.py +13 -0
- ltx2/ltx_core/conditioning/types/attention_strength_wrapper.py +71 -0
- ltx2/ltx_core/conditioning/types/keyframe_cond.py +70 -0
- ltx2/ltx_core/conditioning/types/latent_cond.py +44 -0
- ltx2/ltx_core/conditioning/types/noise_mask_cond.py +45 -0
- ltx2/ltx_core/conditioning/types/reference_video_cond.py +91 -0
- ltx2/ltx_core/guidance/__init__.py +15 -0
- ltx2/ltx_core/guidance/perturbations.py +79 -0
- ltx2/ltx_core/layer_streaming.py +324 -0
- ltx2/ltx_core/loader/__init__.py +48 -0
- ltx2/ltx_core/loader/fuse_loras.py +133 -0
- ltx2/ltx_core/loader/kernels.py +72 -0
- ltx2/ltx_core/loader/module_ops.py +14 -0
- ltx2/ltx_core/loader/primitives.py +146 -0
- ltx2/ltx_core/loader/registry.py +84 -0
- ltx2/ltx_core/loader/sd_ops.py +139 -0
- ltx2/ltx_core/loader/sft_loader.py +66 -0
- ltx2/ltx_core/loader/single_gpu_model_builder.py +136 -0
- ltx2/ltx_core/modality_tiling.py +222 -0
- ltx2/ltx_core/model/__init__.py +8 -0
- ltx2/ltx_core/model/audio_vae/__init__.py +29 -0
- ltx2/ltx_core/model/audio_vae/attention.py +71 -0
- ltx2/ltx_core/model/audio_vae/audio_vae.py +508 -0
- ltx2/ltx_core/model/audio_vae/causal_conv_2d.py +110 -0
- ltx2/ltx_core/model/audio_vae/causality_axis.py +10 -0
- ltx2/ltx_core/model/audio_vae/downsample.py +110 -0
- ltx2/ltx_core/model/audio_vae/model_configurator.py +200 -0
- ltx2/ltx_core/model/audio_vae/ops.py +73 -0
- ltx2/ltx_core/model/audio_vae/resnet.py +176 -0
- ltx2/ltx_core/model/audio_vae/upsample.py +106 -0
- ltx2/ltx_core/model/audio_vae/vocoder.py +594 -0
- ltx2/ltx_core/model/common/__init__.py +9 -0
LICENSE
ADDED
@@ -0,0 +1,381 @@
LTX-2 Community License Agreement

License date: January 5, 2026

By using or distributing any portion or element of LTX-2, you agree to be bound by this Agreement.

1. Definitions.

"Agreement" means the terms and conditions for the license, use, reproduction, and distribution of LTX-2 and the Complementary Materials, as specified in this document.

"Control" means the direct or indirect ownership of more than fifty percent (50%) of the voting securities or other ownership interests, or the power to direct the management and policies of such Entity through voting rights, contract, or otherwise.

"Data" means a collection of information and/or content extracted from the dataset used with LTX-2, including to train, pretrain, or otherwise evaluate LTX-2. The Data is not licensed under this Agreement.

"Derivatives of LTX-2" means all modifications to LTX-2, works based on LTX-2, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of LTX-2, to the other model, in order to cause the other model to perform similarly to LTX-2, including – but not limited to – distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by LTX-2 for training the other model. For clarity, Derivatives of LTX-2 include: (i) any fine-tuned or adapted weights, parameters, or checkpoints derived from LTX-2; (ii) derivative model architectures that incorporate or are based upon LTX-2's architecture; and (iii) any modified or extended versions of the Complementary Materials. All intellectual property rights in Derivatives of LTX-2 shall be subject to the terms of this Agreement, and you may not claim exclusive ownership rights in any Derivatives of LTX-2 that would restrict the rights granted herein.

"Entity" means any individual, corporation, partnership, limited liability company, or other legal entity. For purposes of this Agreement, an Entity shall be deemed to include, on an aggregative basis, all subsidiaries, affiliates, and other companies under common Control with such Entity. When determining whether an Entity meets any threshold under this Agreement (including revenue thresholds), all subsidiaries, affiliates, and companies under common Control shall be considered collectively.

"Harm" includes but is not limited to physical, mental, psychological, financial and reputational damage, pain, or loss.

"Licensor" or "Lightricks" means the owner that is granting the license under this Agreement. For the purposes of this Agreement, the Licensor is Lightricks Ltd.

"LTX-2" means the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code, accompanying source code, scripts, documentation, tutorials, examples, and all other elements of the foregoing distributed and made publicly available by Lightricks (including, for example, at https://github.com/Lightricks/LTX-2) for the LTX-2 model released on January 5, 2026. This license is applicable to all LTX-2 versions released since January 5, 2026, and all future releases of LTX-2 under this license.

"Output" means the results of operating LTX-2 as embodied in informational content resulting therefrom.

"you" (or "your") means an individual or legal Entity licensing LTX-2 in accordance with this Agreement and/or making use of LTX-2 for whichever purpose and in any field of use, including usage of LTX-2 in an end-use application - e.g. chatbot, translator, image generator.

2. Grant of License. Subject to the terms and conditions of this Agreement, you are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Licensor's intellectual property or other rights owned by Licensor embodied in LTX-2 to use, reproduce, prepare, distribute, publicly display, publicly perform, sublicense, copy, create derivative works of, and make modifications to LTX-2, for any purpose, subject to the restrictions set forth in Attachment A; provided however, that Entities with annual revenues of at least $10,000,000 (the "Commercial Entities") are required to obtain a paid commercial use license in order to use LTX-2 and Derivatives of LTX-2, subject to the terms and provisions of a different license (the "Commercial Use Agreement"), as will be provided by the Licensor. Commercial Entities interested in such a commercial license are required to [contact Licensor](https://ltx.io/model/licensing). Any commercial use of LTX-2 or Derivatives of LTX-2 by the Commercial Entities not in accordance with this Agreement and/or the Commercial Use Agreement is strictly prohibited and shall be deemed a material breach of this Agreement. Such material breach will be subject, in addition to any license fees owed to Licensor for the period such Commercial Entity used LTX-2 (as will be determined by Licensor), to liquidated damages, which will be paid to Licensor immediately upon demand, in an amount equal to double the amount that would otherwise have been paid by you for the relevant period of time. Such amount reflects a reasonable estimation of the losses and administrative costs incurred due to such breach. You agree and understand that this remedy does not limit the Licensor's right to pursue other remedies available at law or equity.

3. Distribution and Redistribution. You may host for third parties remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of LTX-2 or Derivatives of LTX-2 thereof in any medium, with or without modifications, provided that you meet the following conditions:

(a) Use-based restrictions as referenced in paragraph 4 and all provisions of Attachment A MUST be included as an enforceable provision by you in any type of legal agreement (e.g. a license) governing the use and/or distribution of LTX-2 or Derivatives of LTX-2, and you shall give notice to subsequent users you distribute to, that LTX-2 or Derivatives of LTX-2 are subject to paragraph 4 and Attachment A in their entirety, including all use restrictions and acceptable use policies;

(b) You must provide any third party recipients of LTX-2 or Derivatives of LTX-2 a copy of this Agreement, including all attachments and use policies. Any Derivative of LTX-2 (as defined in Section 1, including but not limited to fine-tuned weights, modified training code, models trained on Outputs, or any other derivative) must be distributed exclusively under the terms of this Agreement with a complete copy of this license included;

(c) You must cause any modified files to carry prominent notices stating that you changed the files;

(d) You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of LTX-2, Derivatives of LTX-2.

You may add your own copyright statement to your modifications and may provide additional or different license terms and conditions - respecting paragraph 3(a) - for use, reproduction, or distribution of your modifications, or for any such Derivatives of LTX-2 as a whole, provided your use, reproduction, and distribution of LTX-2 otherwise complies with the conditions stated in this Agreement, and you provide a complete copy of this Agreement with any such use, reproduction and distribution of LTX-2 and any Derivatives thereof.

4. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore, you cannot use LTX-2 and the Derivatives of LTX-2 in violation of the specified restricted uses. You may use LTX-2 subject to this Agreement, including only for lawful purposes and in accordance with the Agreement. "Use" may include creating any content with, fine-tuning, updating, running, training, evaluating and/or re-parametrizing LTX-2. You shall require all of your users who use LTX-2 or a Derivative of LTX-2 to comply with the terms of this paragraph 4.

5. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output you generate using LTX-2. You are accountable for input you insert into LTX-2, the Output you generate and its subsequent uses. No use of the Output can contravene any provision as stated in the Agreement.

6. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of LTX-2 in violation of this Agreement, update LTX-2 through electronic means, or modify the Output of LTX-2 based on updates. You shall undertake reasonable efforts to use the latest version of LTX-2. Any use of the non-current version of LTX-2 is done solely at your risk.

7. Export Controls and Sanctions Compliance. You acknowledge that LTX-2, Derivatives of LTX-2 may be subject to export control laws and regulations, including but not limited to the U.S. Export Administration Regulations and sanctions programs administered by the Office of Foreign Assets Control (OFAC). You represent and warrant that you and any users of LTX-2 are not (i) located in, organized under the laws of, or ordinarily resident in any country or territory subject to comprehensive sanctions; (ii) identified on any U.S. government restricted party list, including the Specially Designated Nationals and Blocked Persons List; or (iii) otherwise prohibited from receiving LTX-2 under applicable law. You shall not export, re-export, or transfer LTX-2, directly or indirectly, in violation of any applicable export control or sanctions laws or regulations. You agree to comply with all applicable trade control laws and shall indemnify and hold Licensor harmless from any claims arising from your failure to comply with such laws.

8. Trademarks and related. Nothing in this Agreement permits you to make use of Licensor's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensor.

9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides LTX-2 on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing LTX-2 and Derivatives of LTX-2 and assume any risks associated with your exercise of permissions under this Agreement.

10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Licensor be liable to you for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use LTX-2 (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if Licensor has been advised of the possibility of such damages.

11. Accepting Warranty or Additional Liability. While redistributing LTX-2 and Derivatives of LTX-2, you may, provided you do not violate the terms of this Agreement, choose to offer and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations. However, in accepting such obligations, you may act only on your own behalf and on your sole responsibility, not on behalf of Licensor, and only if you agree to indemnify, defend, and hold Licensor harmless for any liability incurred by, or claims asserted against Licensor, by reason of your accepting any such warranty or additional liability.

12. Governing Law. This Agreement and all relations, disputes, claims and other matters arising hereunder (including non-contractual disputes or claims) will be governed exclusively by, and construed exclusively in accordance with, the laws of the State of New York. To the extent permitted by law, choice of laws rules and the United Nations Convention on Contracts for the International Sale of Goods will not apply. For the purposes of adjudicating any action or proceeding to enforce the terms of this Agreement, you hereby irrevocably consent to the exclusive jurisdiction of, and venue in, the federal and state courts located in the County of New York within the State of New York. The prevailing party in any claim or dispute between the parties under this Agreement will be entitled to reimbursement of its reasonable attorneys' fees and costs. You hereby waive the right to a trial by jury, to participate in a class or representative action (including in arbitration), or to combine individual proceedings in court or in arbitration without the consent of all parties.

13. Term and Termination. This Agreement is effective upon your acceptance and continues until terminated. Licensor may terminate this Agreement immediately upon written notice to you if you breach any provision of this Agreement, including but not limited to violations of the use restrictions in Attachment A or unauthorized commercial use. Upon termination: (a) all rights granted to you under this Agreement will immediately cease; (b) you must immediately cease all use of LTX-2 and Derivatives of LTX-2; (c) you must delete or destroy all copies of LTX-2 and Derivatives of LTX-2 in your possession or control; and (d) you must notify any third parties to whom you distributed LTX-2 or Derivatives of LTX-2 of the termination. Sections 8-13, and Section 15 shall survive termination of this Agreement. Termination does not relieve you of any obligations incurred prior to termination, including payment obligations under Section 2. In addition, if You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Licensor or any person or entity alleging that LTX-2 or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licenses granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed.

14. Disputes and Arbitration. All disputes arising in connection with this Agreement shall be finally settled by arbitration under the Rules of Arbitration of the International Chamber of Commerce ("ICC Rules"), by one (1) arbitrator appointed in accordance with the ICC Rules. The seat of arbitration shall be New York, NY, USA, and the proceedings shall be conducted in English. The arbitrator shall be empowered to grant any relief that a court could grant. Judgment on the arbitration award may be entered by any court having jurisdiction thereof. Each party waives its right to a trial by jury and to participate in any class or representative action.

15. If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

END OF TERMS AND CONDITIONS

ATTACHMENT A: Use Restrictions

When using the Outputs, LTX-2 and any Derivatives thereof, you will comply with the Acceptable Use Policy. In addition, you agree not to use the Outputs, LTX-2 or its Derivatives in any of the following ways:

1. In any way that violates any applicable national, federal, state, local or international law or regulation;

2. For the purpose of exploiting, Harming or attempting to exploit or Harm minors in any way;

3. To generate or disseminate false information and/or content with the purpose of Harming others;

4. To generate or disseminate personal identifiable information that can be used to Harm an individual;

5. To generate or disseminate information and/or content (e.g. images, code, posts, articles), and place the information and/or content in any context (e.g. bot generating tweets) without expressly and intelligibly disclaiming that the information and/or content is machine generated;

6. To defame, disparage or otherwise harass others;

7. To impersonate or attempt to impersonate (e.g. deepfakes) others without their consent;

8. For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation;

9. For any use intended to or which has the effect of discriminating against or Harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;

10. To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological Harm;

11. For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;

12. To provide medical advice and medical results interpretation;

13. To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use);

14. To generate and/or disseminate malware (including – but not limited to – ransomware) or any other content to be used for the purpose of harming electronic systems;

15. To engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, or other essential goods and services;

16. To engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals;

17. For military, warfare, nuclear industries or applications, weapons development, or any use in connection with activities that may cause death, personal injury, or severe physical or environmental damage;

18. For commercial use only: To train, improve, or fine-tune any other machine learning model, artificial intelligence system, or competing model, except for Derivatives of LTX-2 as expressly permitted under this Agreement;

19. To circumvent, disable, or interfere with any technical limitations, safety features, content filters, or use restrictions implemented in LTX-2 by Licensor;

20. To use LTX-2 or Derivatives of LTX-2 in any product, service, or application that directly competes with Licensor's commercial products or services, or is designed to replace or substitute Licensor's offerings in the market, without obtaining a separate commercial license from Licensor.
README.md
CHANGED
@@ -1,42 +1,162 @@
----
-title: DramaBox
-emoji: 🎭
-colorFrom: red
-colorTo: indigo
-sdk: gradio
-sdk_version: 4.44.1
-app_file: app.py
-pinned: true
-license: other
-license_name: ltx-2-community
-license_link: https://huggingface.co/ResembleAI/Dramabox/blob/main/LICENSE
-hardware: l40s
-short_description: Expressive TTS with voice cloning — DramaBox demo
----
-- `src/inference_server.py` — warm `TTSServer` (single load, ~2.5s/request)
-- `src/inference.py` — CLI inference
-- `src/model_downloader.py` — auto-fetches model from HuggingFace
-- `ltx2/` — vendored LTX-2 pipelines
-- `requirements.txt` — Python deps (includes `resemble-perth`)

# DramaBox - Expressive TTS with Voice Cloning

Prompt-driven TTS with voice cloning built on a 3.3B Diffusion Transformer with flow matching.

## Folder Structure

```
DramaBox/
├── src/
│   ├── inference.py            # TTS inference with voice cloning
│   ├── inference_server.py     # Warm server (~2.5s per generation)
│   ├── audio_conditioning.py   # Reference audio conditioning
│   └── model_downloader.py     # Auto-download models from HuggingFace
├── patches/
│   ├── attention.py            # dtype fix for mask allocation
│   └── guiders.py              # Per-token CFG clamping
├── assets/
│   └── silence_latent_frame.pt
├── evals/
│   ├── eval_short.txt          # 30 short prompts (~5-15s)
│   ├── eval_long.txt           # 15 long prompts (~20-37s)
│   └── eval_expressive.txt     # 15 expressive prompts (laughs, sighs, stammers)
├── scripts/
│   ├── inference.sh            # Inference wrapper
│   └── eval.sh                 # Evaluation runner
├── app.py                      # Gradio demo app
├── ltx2/                       # LTX-2 dependency packages
└── README.md
```

## Models

Models auto-download from [ResembleAI/Dramabox](https://huggingface.co/ResembleAI/Dramabox) on HuggingFace.

| Model | Size | Description |
|-------|------|-------------|
| `dramabox-dit-v1.safetensors` | 6.6 GB | DiT transformer |
| `dramabox-audio-components.safetensors` | 2.7 GB | Audio VAE + vocoder + text projection |
| [unsloth/gemma-3-12b-it-bnb-4bit](https://huggingface.co/unsloth/gemma-3-12b-it-bnb-4bit) | ~8 GB | Text encoder (auto-downloaded) |

**VRAM**: ~24 GB peak | **Speed**: ~2.5s per generation (warm server, H100)
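
If you prefer to fetch the checkpoints yourself rather than rely on `src/model_downloader.py`, a minimal `huggingface_hub` sketch follows; the filenames come from the table above, and the actual download logic in `model_downloader.py` may differ:

```python
from huggingface_hub import hf_hub_download

# Fetch the two DramaBox checkpoints into the local hub cache.
# A manual alternative to src/model_downloader.py; a sketch, not the repo's code.
dit_path = hf_hub_download("ResembleAI/Dramabox", "dramabox-dit-v1.safetensors")
aux_path = hf_hub_download("ResembleAI/Dramabox", "dramabox-audio-components.safetensors")
print(dit_path, aux_path)
```
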
## Quick Start

### Warm Server (recommended, ~2.5s per request)

```python
from src.inference_server import TTSServer

server = TTSServer(device="cuda")

server.generate_to_file(
    prompt='A woman speaks warmly, "Hello, how are you today?" She laughs, "Hahaha, it is so good to see you!"',
    output="output.wav",
    voice_ref="reference.wav",  # optional, 10+ seconds
)
```

### Gradio App

```bash
GEMINI_API_KEY=your_key CUDA_VISIBLE_DEVICES=4 python app.py
```

### CLI Inference

```bash
python src/inference.py \
    --voice-sample reference.wav \
    --prompt 'A woman speaks warmly, "Hello, how are you today?"' \
    --output output.wav \
    --cfg-scale 2.5 --stg-scale 1.5
```

### Evaluation

```bash
bash scripts/eval.sh --eval expressive --output eval_results/
```

## Inference Settings

| Parameter | Default | Notes |
|-----------|---------|-------|
| cfg-scale | 2.5 | Lower = more natural; higher = closer text following |
| stg-scale | 1.5 | Skip-token guidance |
| rescale | 0 | No rescaling |
| modality | 1 | No modality guidance |
| duration-multiplier | 1.1 | 10% breathing room |
| steps | 30 | Euler flow matching |

## Prompt Writing Guide

**Structure:** `<speaker description>, "<dialogue>" <action direction> "<more dialogue>"`

### What works inside quotes (model produces actual sounds)
- Laughs: `"Hahaha"` `"Hehehe"` (always one word, never separated)
- Sounds: `"Mmmmm"` `"Ugh"` `"Argh"` `"Ahhh"` `"Hmm"`

### What goes outside quotes (stage directions)
- `She sighs deeply.` `He gulps nervously.` `A long pause.`
- `Her voice cracks.` `He clears his throat.` `She scoffs.`

### Never inside quotes (model speaks them literally)
- Ahem, Pfft, Sigh, Gasp, Cough

### Tips
- Match gender/age in speaker description to voice reference
- Break long dialogue into segments with acting directions between them
- End the prompt at the last closing quote mark (no trailing descriptions); a complete example follows below
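
Putting these rules together, here is one of the demo's built-in example prompts (taken verbatim from `EXAMPLES` in `app.py`):

> A shadowy villain speaks with cold menace, "You have entered my domain, mortal." He chuckles darkly, "Such arrogance will be your undoing." His voice rises with fury, "Kneel, or be destroyed where you stand!"

Speaker description and stage directions sit outside the quotes; every word the model should voice sits inside them.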

## Watermarking

Every audio output from `inference.py` and `inference_server.TTSServer.generate_to_file` is automatically watermarked with [Resemble Perth](https://github.com/resemble-ai/Perth) — an imperceptible neural watermark that survives MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.

```python
import perth, librosa
wav, sr = librosa.load("output.wav", sr=None, mono=True)
detector = perth.PerthImplicitWatermarker()
print(detector.get_watermark(wav, sample_rate=sr))  # confidence ≈ 1.0 for our outputs
```

Pass `--no-watermark` to `inference.py` (or `watermark=False` to `generate_to_file`) to disable for debugging.

## Training

DramaBox is an IC-LoRA fine-tune of the LTX-2.3 22B audio-only branch. To train your own:

```bash
# 1. Preprocess raw (audio, transcript) pairs → audio_latents/ + conditions/
python src/preprocess.py \
    --dataset-type manifest \
    --index your_data.jsonl \
    --output-dir /path/to/preprocessed/ \
    --checkpoint dramabox-audio-components.safetensors \
    --gemma-root /path/to/gemma-3-12b-it-bnb-4bit/

# 2. Edit configs/training_args.example.yaml → your data paths

# 3. Launch (uses HuggingFace accelerate)
bash scripts/train.sh \
    --config configs/training_args.example.yaml \
    --gpus 0,1,2,3,4,5,6 \
    --train-val-gpu 7
```

| Script | Purpose |
|---|---|
| `src/preprocess.py` | Encode audio (Audio VAE) + text (Gemma) into training-ready `.pt` files |
| `src/train.py` | IC-LoRA training loop with peft, accelerate multi-GPU, periodic validation |
| `src/validate.py` | Spawned by `train.py` at each save step; runs the warm validator on a held-out prompt set |
| `scripts/train.sh` | YAML-config wrapper around `accelerate launch src/train.py` |

LoRA targets the audio branch only: `audio_attn1.{to_q,to_k,to_v,to_out.0}` + `audio_ff.{net.0.proj,net.2}` × 48 transformer blocks (288 LoRA pairs total). Default rank 128 / alpha 128 / dropout 0.1, cosine LR schedule from 1e-4 with 500-step warmup over 10k steps.
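
That target-module list can be spelled out mechanically. A minimal `peft` sketch; the `transformer_blocks.{i}` path prefix is an assumption for illustration, and the exact module paths live in `src/train.py`:

```python
from peft import LoraConfig

# 48 blocks × (4 attention projections + 2 feed-forward projections) = 288 LoRA pairs.
# NOTE: the "transformer_blocks.{i}" prefix is hypothetical; check src/train.py
# for the real module paths in the vendored DiT.
AUDIO_TARGETS = [
    "audio_attn1.to_q", "audio_attn1.to_k", "audio_attn1.to_v", "audio_attn1.to_out.0",
    "audio_ff.net.0.proj", "audio_ff.net.2",
]
target_modules = [
    f"transformer_blocks.{i}.{name}" for i in range(48) for name in AUDIO_TARGETS
]
lora_config = LoraConfig(
    r=128, lora_alpha=128, lora_dropout=0.1,  # defaults from the paragraph above
    target_modules=target_modules,
)
```
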
## Language

English.

## License

Built on [LTX-2](https://github.com/Lightricks/LTX-2) by Lightricks. Distributed under the LTX-2 Community License Agreement — see [`LICENSE`](LICENSE).
app.py
ADDED
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
"""DramaBox — Gradio demo (warm server).

Loads the warm TTSServer once, then handles requests at ~2.5 s each. All
generated audio is invisibly watermarked with Resemble Perth before being
returned to the user.
"""
import logging
import os
import sys
import tempfile
import time

import gradio as gr

# Local src import.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
from inference_server import TTSServer  # noqa: E402


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...")
tts = TTSServer(
    device="cuda",
    dtype=os.environ.get("LTX_DTYPE", "bf16"),
    compile_model=os.environ.get("LTX_COMPILE", "0") == "1",
    bnb_4bit=True,  # default Gemma is unsloth pre-quantized
)
logging.info("Server ready.")


# ── Example prompts (shown as click-to-fill chips in the UI) ─────────────────
EXAMPLES: list[tuple[str, str]] = [
    (
        "Villain monologue",
        'A shadowy villain speaks with cold menace, "You have entered my domain, mortal." '
        'He chuckles darkly, "Such arrogance will be your undoing." '
        'His voice rises with fury, "Kneel, or be destroyed where you stand!"'
    ),
    (
        "Talk-show host wheeze-laugh",
        'A talk show host gasps with shock, "No! You did NOT just say that!" '
        'He bursts into uncontrollable laughter, "Hahaha! Oh my god, oh my god!" '
        'He wheezes, "I cannot, I literally cannot breathe right now!"'
    ),
    (
        "Tender goodnight whisper",
        'A woman speaks tenderly, "It has been a long day, my love." '
        'She whispers, "Close your eyes. I am right here." '
        'She hums quietly, "Mmmm-mmm. Sleep now."'
    ),
    (
        "Old-school radio anchor",
        'A radio host clears his throat, "Excuse me, pardon that." '
        'He settles into a warm, professional tone, "Good evening everyone, '
        'and welcome back to the show. We have got a wonderful lineup tonight."'
    ),
    (
        "Catgirl uncontrollable giggling",
        'A playful girl already mid-giggle, "Hehehe, oh my gosh you should see your face!" '
        'She gasps for air between giggles, "Oh my, hehe, oh my, I cannot stop!" '
        'She tries to compose herself, "Ahhhhh okay okay okay, I will stop, I promise."'
    ),
    (
        "Hero stammering courage",
        'A young warrior speaks with a trembling voice, "I... I do not know if I can do this." '
        'He takes a shaky breath, "But someone has to try." '
        'His voice steadies with growing fire, "No more running. I WILL fight!"'
    ),
    (
        "Exhausted dad, fraying patience",
        'An exhausted father speaks with fraying patience, "Sweetie, daddy is asking very nicely." '
        'He sighs deeply, "Ohhhh my goodness." '
        'He puts on an overly cheerful voice, "Hey buddy! Look at the shiny thing!" '
        'Then he laughs helplessly, "Hahaha, I am losing my mind."'
    ),
    (
        "Smug-confident announcer",
        'A confident announcer speaks proudly, "And now, the moment you have all been waiting for." '
        'He chuckles knowingly, "Heheh, trust me, this one is going to blow you away."'
    ),
]


def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float, seed: int):
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is empty.")
    t0 = time.time()
    ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None
    output = tempfile.mktemp(suffix=".wav", prefix="dramabox_")
    tts.generate_to_file(
        prompt=prompt,
        output=output,
        voice_ref=ref_path,
        cfg_scale=cfg, stg_scale=stg,
        duration_multiplier=dur_mult, seed=int(seed),
    )
    elapsed = time.time() - t0
    logging.info(f"Generated in {elapsed:.2f}s -> {output}")
    return output


# ── UI ──────────────────────────────────────────────────────────────────────
with gr.Blocks(
    title="DramaBox — Expressive TTS",
    theme=gr.themes.Default(),
    css=".prompt-box textarea { font-size: 14px !important; line-height: 1.5 !important; }",
) as app:
    gr.Markdown("# 🎭 DramaBox — Expressive TTS with Voice Cloning")
    gr.Markdown(
        "Write a scene prompt, optionally upload a 10-second voice reference, "
        "and generate. Audio is automatically watermarked with "
        "[Resemble Perth](https://github.com/resemble-ai/Perth).\n\n"
        "**Tips:** put dialogue inside `\"double quotes\"`, scene directions outside. "
        "Phonetic sounds (`\"Hahaha\"`, `\"Mmmm\"`, `\"Ugh\"`) go inside quotes; named "
        "actions (`She sighs.`, `He clears his throat.`) go outside."
    )

    with gr.Row():
        with gr.Column(scale=3):
            prompt_box = gr.Textbox(
                label="Scene prompt",
                placeholder=EXAMPLES[0][1],
                lines=6, elem_classes=["prompt-box"],
            )
            example_chooser = gr.Dropdown(
                choices=[e[0] for e in EXAMPLES],
                label="Load an example prompt", interactive=True, value=None,
            )
            audio_ref = gr.Audio(
                label="Voice reference (optional, 10+ seconds)",
                type="filepath",
            )
            gen_btn = gr.Button("Generate", variant="primary", size="lg")

        with gr.Column(scale=2):
            with gr.Accordion("Inference settings", open=False):
                cfg_slider = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale")
                stg_slider = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale")
                dur_slider = gr.Slider(0.8, 2.0, value=1.1, step=0.05, label="Duration ×")
                seed_input = gr.Number(value=42, label="Seed", precision=0)
            audio_out = gr.Audio(label="Generated audio", type="filepath")
            with gr.Accordion("Prompt writing guide", open=False):
                gr.Markdown(
                    "**Structure:** `<speaker description>, \"<dialogue>\" <action> \"<more dialogue>\"`\n\n"
                    "**Inside quotes** (model speaks them):\n"
                    "- Dialogue: `\"Hello, how are you?\"`\n"
                    "- Phonetic sounds: `\"Hahaha\"`, `\"Hehehe\"`, `\"Mmmmm\"`, `\"Ugh\"`, `\"Argh\"`\n\n"
                    "**Outside quotes** (stage directions):\n"
                    "- `She sighs deeply.`, `He gulps nervously.`, `A long pause.`\n"
                    "- `Her voice cracks.`, `He clears his throat.`\n\n"
                    "**Avoid inside quotes:** Ahem, Pfft, Sigh, Gasp, Cough — the model speaks them literally."
                )

    def _load_example(choice: str):
        if not choice:
            return gr.update()
        for name, prompt in EXAMPLES:
            if name == choice:
                return prompt
        return gr.update()

    example_chooser.change(_load_example, inputs=[example_chooser], outputs=[prompt_box])
    gen_btn.click(
        on_generate,
        inputs=[prompt_box, audio_ref, cfg_slider, stg_slider, dur_slider, seed_input],
        outputs=[audio_out],
    )


if __name__ == "__main__":
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7861"))
    app.queue(max_size=10).launch(
        server_name="0.0.0.0", server_port=port,
        share=os.environ.get("GRADIO_SHARE", "0") == "1",
    )
assets/silence_latent_frame.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f73746d2163f8f1742c5de89005404ccaeeff05154bbb10a3337bf9bd13f161c
size 1501
configs/training_args.example.yaml
ADDED
@@ -0,0 +1,53 @@
# Example DramaBox IC-LoRA training config. Used by scripts/train.sh.

# Where to load preprocessed `audio_latents/` + `conditions/` shards from.
data_dir:
  - /path/to/preprocessed_dataset_a/
  - /path/to/preprocessed_dataset_b/

# One index file per data_dir entry. Each line:
# <sample_id>~<speaker_id>~<lang>~<sample_rate>~<offset>~<duration>~<phonemes>~<text>
speaker_index:
  - /path/to/preprocessed_dataset_a/index.txt
  - /path/to/preprocessed_dataset_b/index.txt

# Output directory (relative is fine — resolved against the repo root).
output_dir: tts_iclora_v1

# LTX-2.3 22B base. Same file is used for the transformer + the aux stack
# (PromptEncoder, AudioVAE, AudioDecoder).
checkpoint: ltx-2.3-22b-dev.safetensors
full_checkpoint: ltx-2.3-22b-dev.safetensors
base_model: dev

# LoRA hyperparams. rank == alpha is the simplest setup (scale = 1.0).
lora_rank: 128
lora_alpha: 128
lora_dropout: 0.1

# Voice-cloning ref-token settings.
ref_ratio: 0.3        # fraction of training samples that get a ref token
max_ref_tokens: 200   # max ref-token positions appended to target

text_dropout: 0.4     # CFG training: drop the text prompt with prob 0.4

# Schedule. Use lr_scheduler=constant with a small lr (1e-5) for a "fine-tune"
# resume; cosine + larger lr (1e-4) for from-scratch.
steps: 10000
lr: 1.0e-04
lr_scheduler: cosine
warmup_steps: 500

batch_size: 1
grad_accum: 4
max_grad_norm: 1.0

save_every: 500
log_every: 50
seed: 53

# (Optional) per-checkpoint validation eval — see configs/val_config.example.yaml
# val_config: val_config.example.yaml

# (Optional) resume from a previous LoRA adapter file:
# resume_lora: tts_iclora_v0/lora_step_05000.safetensors
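For illustration, a single record in one of the `speaker_index` files above might read `sample_00042~spk_007~en~44100~0.0~6.32~~She laughs, "Hahaha, good to see you!"` (every field value here is hypothetical): sample id, speaker id, language, sample rate, offset, duration, phonemes (empty here), and transcript.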
ltx2/ltx_core/__init__.py
ADDED
File without changes
ltx2/ltx_core/batch_split.py
ADDED
@@ -0,0 +1,95 @@
"""Batch-splitting adapter for the transformer.

Wraps an ``X0Model`` (or ``LayerStreamingWrapper``) and splits batched inputs
into smaller chunks before forwarding, then concatenates the results. This
controls peak activation memory at the cost of more forward passes.

The adapter is transparent — it has the same ``forward`` signature as
``X0Model`` and proxies attribute access to the wrapped model.

Example
-------
>>> from ltx_core.batch_split import BatchSplitAdapter
>>> adapter = BatchSplitAdapter(model, max_batch_size=1)
>>> # Receives B=4, runs 4x B=1 internally, returns B=4
>>> denoised_video, denoised_audio = adapter(video=v_b4, audio=a_b4, perturbations=ptb)
"""

from __future__ import annotations

from typing import Any

import torch
from torch import nn

from ltx_core.guidance.perturbations import BatchedPerturbationConfig
from ltx_core.model.transformer.modality import Modality


def _split_perturbations(config: BatchedPerturbationConfig, sizes: list[int]) -> list[BatchedPerturbationConfig]:
    """Split a ``BatchedPerturbationConfig`` along the batch dimension."""
    it = iter(config.perturbations)
    return [BatchedPerturbationConfig([next(it) for _ in range(s)]) for s in sizes]


def _merge_tensors(tensors: list[torch.Tensor | None]) -> torch.Tensor | None:
    """Concatenate tensors along batch dim, or return None if all are None."""
    non_none = [t for t in tensors if t is not None]
    if not non_none:
        return None
    return torch.cat(non_none, dim=0)


class BatchSplitAdapter(nn.Module):
    """Wraps a model and splits batched forward calls into smaller chunks.

    Has the same ``forward`` signature as ``X0Model``:
    ``(video, audio, perturbations) -> (denoised_video, denoised_audio)``.

    Args:
        model: The model to wrap (``X0Model``, ``LayerStreamingWrapper``, etc.).
        max_batch_size: Maximum batch size per forward pass. Input batches
            larger than this are split into sequential chunks.
    """

    def __init__(self, model: nn.Module, max_batch_size: int) -> None:
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        super().__init__()
        self._model = model
        self._max_batch_size = max_batch_size

    def _get_chunk_sizes(self, batch_size: int) -> list[int]:
        full, remainder = divmod(batch_size, self._max_batch_size)
        sizes = [self._max_batch_size] * full
        if remainder:
            sizes.append(remainder)
        return sizes

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        batch_size = (video or audio).latent.shape[0]

        if batch_size <= self._max_batch_size:
            return self._model(video=video, audio=audio, perturbations=perturbations)

        sizes = self._get_chunk_sizes(batch_size)
        n = len(sizes)

        v_chunks = video.split(sizes) if video is not None else [None] * n
        a_chunks = audio.split(sizes) if audio is not None else [None] * n
        p_chunks = _split_perturbations(perturbations, sizes)

        chunk_results = [
            self._model(video=vc, audio=ac, perturbations=pc)
            for vc, ac, pc in zip(v_chunks, a_chunks, p_chunks, strict=True)
        ]

        results_v, results_a = zip(*chunk_results, strict=True)
        return _merge_tensors(list(results_v)), _merge_tensors(list(results_a))

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Proxy attribute access to the wrapped model."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._model, name)
ltx2/ltx_core/components/__init__.py
ADDED
@@ -0,0 +1,10 @@
"""
Diffusion pipeline components.
Submodules:
    diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
    guiders         - Guidance strategies (CFGGuider, STGGuider, APG variants)
    noisers         - Noise samplers (GaussianNoiser)
    patchifiers     - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
    protocols       - Protocol definitions (Patchifier, etc.)
    schedulers      - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
"""
ltx2/ltx_core/components/diffusion_steps.py
ADDED
@@ -0,0 +1,106 @@
import torch

from ltx_core.components.protocols import DiffusionStepProtocol
from ltx_core.utils import to_velocity


class EulerDiffusionStep(DiffusionStepProtocol):
    """
    First-order Euler method for diffusion sampling.
    Takes a single step from the current noise level (sigma) to the next by
    computing velocity from the denoised prediction and applying: sample + velocity * dt.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **_kwargs
    ) -> torch.Tensor:
        sigma = sigmas[step_index]
        sigma_next = sigmas[step_index + 1]
        dt = sigma_next - sigma
        velocity = to_velocity(sample, sigma, denoised_sample)

        return (sample.to(torch.float32) + velocity.to(torch.float32) * dt).to(sample.dtype)

| 25 |
+
class Res2sDiffusionStep(DiffusionStepProtocol):
|
| 26 |
+
"""
|
| 27 |
+
Second-order diffusion step for res_2s sampling with SDE noise injection.
|
| 28 |
+
Used by the res_2s denoising loop. Advances the sample from the current
|
| 29 |
+
sigma to the next by mixing a deterministic update (from the denoised
|
| 30 |
+
prediction) with injected noise via ``get_sde_coeff``, producing
|
| 31 |
+
variance-preserving transitions.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def get_sde_coeff(
|
| 36 |
+
sigma_next: torch.Tensor,
|
| 37 |
+
sigma_up: torch.Tensor | None = None,
|
| 38 |
+
sigma_down: torch.Tensor | None = None,
|
| 39 |
+
sigma_max: torch.Tensor | None = None,
|
| 40 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 41 |
+
"""
|
| 42 |
+
Compute SDE coefficients (alpha_ratio, sigma_down, sigma_up) for the step.
|
| 43 |
+
Given either ``sigma_down`` or ``sigma_up``, returns the mixing
|
| 44 |
+
coefficients used for variance-preserving noise injection. If
|
| 45 |
+
``sigma_up`` is provided, ``sigma_down`` and ``alpha_ratio`` are
|
| 46 |
+
derived; if ``sigma_down`` is provided, ``sigma_up`` and
|
| 47 |
+
``alpha_ratio`` are derived.
|
| 48 |
+
"""
|
| 49 |
+
if sigma_down is not None:
|
| 50 |
+
alpha_ratio = (1 - sigma_next) / (1 - sigma_down)
|
| 51 |
+
sigma_up = (sigma_next**2 - sigma_down**2 * alpha_ratio**2).clamp(min=0) ** 0.5
|
| 52 |
+
elif sigma_up is not None:
|
| 53 |
+
# Fallback to avoid sqrt(neg_num)
|
| 54 |
+
sigma_up.clamp_(max=sigma_next * 0.9999)
|
| 55 |
+
sigmax = sigma_max if sigma_max is not None else torch.ones_like(sigma_next)
|
| 56 |
+
sigma_signal = sigmax - sigma_next
|
| 57 |
+
sigma_residual = (sigma_next**2 - sigma_up**2).clamp(min=0) ** 0.5
|
| 58 |
+
alpha_ratio = sigma_signal + sigma_residual
|
| 59 |
+
sigma_down = sigma_residual / alpha_ratio
|
| 60 |
+
else:
|
| 61 |
+
alpha_ratio = torch.ones_like(sigma_next)
|
| 62 |
+
sigma_down = sigma_next
|
| 63 |
+
sigma_up = torch.zeros_like(sigma_next)
|
| 64 |
+
|
| 65 |
+
sigma_up = torch.nan_to_num(sigma_up if sigma_up is not None else torch.zeros_like(sigma_next), 0.0)
|
| 66 |
+
# Replace NaNs in sigma_down with corresponding sigma_next elements (float32)
|
| 67 |
+
nan_mask = torch.isnan(sigma_down)
|
| 68 |
+
sigma_down[nan_mask] = sigma_next[nan_mask].to(sigma_down.dtype)
|
| 69 |
+
alpha_ratio = torch.nan_to_num(alpha_ratio, 1.0)
|
| 70 |
+
|
| 71 |
+
return alpha_ratio, sigma_down, sigma_up
|
| 72 |
+
|
| 73 |
+
def step(
|
| 74 |
+
self,
|
| 75 |
+
sample: torch.Tensor,
|
| 76 |
+
denoised_sample: torch.Tensor,
|
| 77 |
+
sigmas: torch.Tensor,
|
| 78 |
+
step_index: int,
|
| 79 |
+
noise: torch.Tensor,
|
| 80 |
+
eta: float = 0.5,
|
| 81 |
+
) -> torch.Tensor:
|
| 82 |
+
"""Advance one step with SDE noise injection via get_sde_coeff.
|
| 83 |
+
Args:
|
| 84 |
+
sample: Current noisy sample.
|
| 85 |
+
denoised_sample: Denoised prediction from the model.
|
| 86 |
+
sigmas: Noise schedule tensor.
|
| 87 |
+
step_index: Current step index in the schedule.
|
| 88 |
+
noise: Random noise tensor for stochastic injection.
|
| 89 |
+
eta: Controls stochastic noise injection strength (0=deterministic, 1=maximum). Default 0.5.
|
| 90 |
+
Returns:
|
| 91 |
+
Next sample with SDE noise injection applied.
|
| 92 |
+
"""
|
| 93 |
+
sigma = sigmas[step_index]
|
| 94 |
+
sigma_next = sigmas[step_index + 1]
|
| 95 |
+
alpha_ratio, sigma_down, sigma_up = self.get_sde_coeff(sigma_next, sigma_up=sigma_next * eta)
|
| 96 |
+
output_dtype = denoised_sample.dtype
|
| 97 |
+
if torch.any(sigma_up == 0) or torch.any(sigma_next == 0):
|
| 98 |
+
return denoised_sample
|
| 99 |
+
|
| 100 |
+
# Extract epsilon prediction
|
| 101 |
+
eps_next = (sample - denoised_sample) / (sigma - sigma_next)
|
| 102 |
+
denoised_next = sample - sigma * eps_next
|
| 103 |
+
|
| 104 |
+
# Mix deterministic and stochastic components
|
| 105 |
+
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise
|
| 106 |
+
return x_noised.to(output_dtype)
|
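To make the Euler update concrete, here is a minimal sampling-loop sketch. The `denoise_fn` is a hypothetical stand-in for the real transformer, and the velocity conversion assumes the common rectified-flow form (sample - denoised) / sigma; the actual conversion lives in `ltx_core.utils.to_velocity`, which is outside this diff.

import torch

def denoise_fn(sample: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Hypothetical model: shrink toward zero as sigma decreases (illustration only).
    return sample * (1.0 - sigma)

sigmas = torch.linspace(1.0, 0.0, 9)  # 8 steps; normally supplied by a SchedulerProtocol
sample = torch.randn(1, 16, 32)

for i in range(len(sigmas) - 1):
    sigma, sigma_next = sigmas[i], sigmas[i + 1]
    denoised = denoise_fn(sample, sigma)
    velocity = (sample - denoised) / sigma.clamp(min=1e-8)  # assumed rectified-flow velocity
    sample = sample + velocity * (sigma_next - sigma)       # the Euler step: sample + velocity * dt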
ltx2/ltx_core/components/guiders.py
ADDED
@@ -0,0 +1,383 @@
import math
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field

import torch

from ltx_core.components.protocols import GuiderProtocol


@dataclass(frozen=True)
class CFGGuider(GuiderProtocol):
    """
    Classifier-free guidance (CFG) guider.
    Computes the guidance delta as (scale - 1) * (cond - uncond), steering the
    denoising process toward the conditioned prediction.
    Attributes:
        scale: Guidance strength. 1.0 means no guidance; higher values increase
            adherence to the conditioning.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        return (self.scale - 1) * (cond - uncond)

    def enabled(self) -> bool:
        return self.scale != 1.0


@dataclass(frozen=True)
class CFGStarRescalingGuider(GuiderProtocol):
    """
    Calculates the CFG delta between conditioned and unconditioned samples.
    To minimize offset in the denoising direction and move mostly along the
    conditioning axis within the distribution, the unconditioned sample is
    rescaled in accordance with the norm of the conditioned sample.
    Attributes:
        scale (float):
            Global guidance strength. A value of 1.0 corresponds to no extra
            guidance beyond the base model prediction. Values > 1.0 increase
            the influence of the conditioned sample relative to the
            unconditioned one.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        rescaled_neg = projection_coef(cond, uncond) * uncond
        return (self.scale - 1) * (cond - rescaled_neg)

    def enabled(self) -> bool:
        return self.scale != 1.0


@dataclass(frozen=True)
class STGGuider(GuiderProtocol):
    """
    Calculates the STG delta between conditioned and perturbed denoised samples.
    Perturbed samples are the result of the denoising process with perturbations,
    e.g. attentions acting as passthrough for certain layers and modalities.
    Attributes:
        scale (float):
            Global strength of the STG guidance. A value of 0.0 disables the
            guidance. Larger values increase the correction applied in the
            direction of (pos_denoised - perturbed_denoised).
    """

    scale: float

    def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
        return self.scale * (pos_denoised - perturbed_denoised)

    def enabled(self) -> bool:
        return self.scale != 0.0


@dataclass(frozen=True)
class LtxAPGGuider(GuiderProtocol):
    """
    Calculates the APG (adaptive projected guidance) delta between conditioned
    and unconditioned samples.
    To minimize offset in the denoising direction and move mostly along the
    conditioning axis within the distribution, the (cond - uncond) delta is
    decomposed into components parallel and orthogonal to the conditioned
    sample. The `eta` parameter weights the parallel component, while `scale`
    is applied to the orthogonal component. Optionally, a norm threshold can
    be used to suppress guidance when the magnitude of the correction is small.
    Attributes:
        scale (float):
            Strength applied to the component of the guidance that is orthogonal
            to the conditioned sample. Controls how aggressively we move in
            directions that change semantics but stay consistent with the
            conditioning manifold.
        eta (float):
            Weight of the component of the guidance that is parallel to the
            conditioned sample. A value of 1.0 keeps the full parallel
            component; values in [0, 1] attenuate it, and values > 1.0 amplify
            motion along the conditioning direction.
        norm_threshold (float):
            Minimum L2 norm of the guidance delta below which the guidance
            can be reduced or ignored (depending on implementation).
            This is useful for avoiding noisy or unstable updates when the
            guidance signal is very small.
    """

    scale: float
    eta: float = 1.0
    norm_threshold: float = 0.0

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond
        if self.norm_threshold > 0:
            ones = torch.ones_like(guidance)
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
            guidance = guidance * scale_factor
        proj_coeff = projection_coef(guidance, cond)
        g_parallel = proj_coeff * cond
        g_orth = guidance - g_parallel
        g_apg = g_parallel * self.eta + g_orth

        return g_apg * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0


@dataclass(frozen=False)
class LegacyStatefulAPGGuider(GuiderProtocol):
    """
    Calculates the APG (adaptive projected guidance) delta between conditioned
    and unconditioned samples.
    To minimize offset in the denoising direction and move mostly along the
    conditioning axis within the distribution, the (cond - uncond) delta is
    decomposed into components parallel and orthogonal to the conditioned
    sample. The `eta` parameter weights the parallel component, while `scale`
    is applied to the orthogonal component. Optionally, a norm threshold can
    be used to suppress guidance when the magnitude of the correction is small.
    Attributes:
        scale (float):
            Strength applied to the component of the guidance that is orthogonal
            to the conditioned sample. Controls how aggressively we move in
            directions that change semantics but stay consistent with the
            conditioning manifold.
        eta (float):
            Weight of the component of the guidance that is parallel to the
            conditioned sample. A value of 1.0 keeps the full parallel
            component; values in [0, 1] attenuate it, and values > 1.0 amplify
            motion along the conditioning direction.
        norm_threshold (float):
            Minimum L2 norm of the guidance delta below which the guidance
            can be reduced or ignored (depending on implementation).
            This is useful for avoiding noisy or unstable updates when the
            guidance signal is very small.
        momentum (float):
            Exponential moving-average coefficient for accumulating guidance
            over time: running_avg = momentum * running_avg + guidance.
    """

    scale: float
    eta: float
    norm_threshold: float = 5.0
    momentum: float = 0.0
    # It is the user's responsibility not to reuse the same APGGuider across several
    # denoisings or different modalities, so the accumulated average is not shared between them.
    running_avg: torch.Tensor | None = None

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond
        if self.momentum != 0:
            if self.running_avg is None:
                self.running_avg = guidance.clone()
            else:
                self.running_avg = self.momentum * self.running_avg + guidance
            guidance = self.running_avg

        if self.norm_threshold > 0:
            ones = torch.ones_like(guidance)
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
            guidance = guidance * scale_factor

        proj_coeff = projection_coef(guidance, cond)
        g_parallel = proj_coeff * cond
        g_orth = guidance - g_parallel
        g_apg = g_parallel * self.eta + g_orth

        return g_apg * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0


@dataclass(frozen=True)
class MultiModalGuiderParams:
    """
    Parameters for the multi-modal guider.
    """

    cfg_scale: float = 1.0
    "CFG (classifier-free guidance) scale controlling how strongly the model adheres to the prompt."
    stg_scale: float = 0.0
    "STG (spatio-temporal guidance) scale controlling how strongly the model reacts to the perturbation of the modality."
    stg_blocks: list[int] | None = field(default_factory=list)
    "Which transformer blocks to perturb for STG."
    rescale_scale: float = 0.0
    "Rescale scale controlling how strongly the model rescales the modality after applying other guidance."
    modality_scale: float = 1.0
    "Modality scale controlling how strongly the model reacts to isolating the modality."
    cfg_clamp_scale: float = 0.0
    "Clamp guided prediction std to this multiple of conditioned prediction std. 0 = disabled."
    skip_step: int = 0
    "Skip interval: when non-zero, guidance is applied only on every (skip_step + 1)-th step and skipped otherwise."


def _params_for_sigma_from_sorted_dict(
    sigma: float, params_by_sigma: Sequence[tuple[float, MultiModalGuiderParams]]
) -> MultiModalGuiderParams:
    """
    Return params for the given sigma from a sorted (sigma_upper_bound -> params) structure.
    Keys are sorted descending (bin upper bounds). Bin i is (key_{i+1}, key_i].
    Collect all keys >= sigma and use the last one in the list (the smallest such key,
    i.e. the upper bound of the bin containing sigma), or the first entry in the sequence
    (the max key) if the list is empty because sigma lies above all keys.
    """
    if not params_by_sigma:
        raise ValueError("params_by_sigma must be non-empty")
    sigma = float(sigma)
    keys_desc = [k for k, _ in params_by_sigma]
    keys_ge_sigma = [k for k in keys_desc if k >= sigma]
    # sigma above all keys: use first bin (max key)
    key = keys_ge_sigma[-1] if keys_ge_sigma else keys_desc[0]
    return next(p for k, p in params_by_sigma if k == key)


@dataclass(frozen=True)
class MultiModalGuider:
    """
    Multi-modal guider with constant params per instance.
    For sigma-dependent params, use MultiModalGuiderFactory.build_from_sigma(sigma) to
    obtain a guider for each step.
    """

    params: MultiModalGuiderParams
    negative_context: torch.Tensor | None = None

    def calculate(
        self,
        cond: torch.Tensor,
        uncond_text: torch.Tensor | float,
        uncond_perturbed: torch.Tensor | float,
        uncond_modality: torch.Tensor | float,
    ) -> torch.Tensor:
        """
        The guider calculates the guidance delta as (scale - 1) * (cond - uncond) for CFG and
        modality CFG, and as scale * (cond - uncond) for STG, steering the denoising process
        away from the unconditioned prediction.
        """
        pred = (
            cond
            + (self.params.cfg_scale - 1) * (cond - uncond_text)
            + self.params.stg_scale * (cond - uncond_perturbed)
            + (self.params.modality_scale - 1) * (cond - uncond_modality)
        )

        if self.params.rescale_scale != 0:
            factor = cond.std() / pred.std()
            factor = self.params.rescale_scale * factor + (1 - self.params.rescale_scale)
            pred = pred * factor

        # Clamp guided prediction to prevent trajectory overshoot.
        # Instead of global std (which averages over all tokens), clamp per-token.
        # This catches individual tokens that overshoot even if the global std looks fine.
        if self.params.cfg_clamp_scale > 0:
            cfg_delta = pred - cond
            # Per-token magnitude clamping
            delta_norm = cfg_delta.norm(dim=-1, keepdim=True)  # [B, T, 1]
            cond_norm = cond.norm(dim=-1, keepdim=True)
            max_norm = cond_norm * self.params.cfg_clamp_scale
            # Clamp tokens where the delta exceeds the max
            scale = torch.where(
                delta_norm > max_norm,
                max_norm / delta_norm.clamp(min=1e-8),
                torch.ones_like(delta_norm),
            )
            pred = cond + cfg_delta * scale

        return pred

    def do_unconditional_generation(self) -> bool:
        """Returns True if the guider is doing unconditional generation."""
        return not math.isclose(self.params.cfg_scale, 1.0)

    def do_perturbed_generation(self) -> bool:
        """Returns True if the guider is doing perturbed generation."""
        return not math.isclose(self.params.stg_scale, 0.0)

    def do_isolated_modality_generation(self) -> bool:
        """Returns True if the guider is doing isolated modality generation."""
        return not math.isclose(self.params.modality_scale, 1.0)

    def should_skip_step(self, step: int) -> bool:
        """Returns True if the guider should skip the step."""
        if self.params.skip_step == 0:
            return False
        return step % (self.params.skip_step + 1) != 0


@dataclass(frozen=True)
class MultiModalGuiderFactory:
    """
    Factory that creates a MultiModalGuider for a given sigma.
    Single source of truth: _params_by_sigma (the schedule). Use constant() for one set of
    params across all sigmas, and from_dict() for sigma-binned params.
    """

    negative_context: torch.Tensor | None = None
    _params_by_sigma: tuple[tuple[float, MultiModalGuiderParams], ...] = ()

    @classmethod
    def constant(
        cls,
        params: MultiModalGuiderParams,
        negative_context: torch.Tensor | None = None,
    ) -> "MultiModalGuiderFactory":
        """Build a factory with constant params (same guider for all sigma)."""
        return cls(
            negative_context=negative_context,
            _params_by_sigma=((float("inf"), params),),
        )

    @classmethod
    def from_dict(
        cls,
        sigma_to_params: Mapping[float, MultiModalGuiderParams],
        negative_context: torch.Tensor | None = None,
    ) -> "MultiModalGuiderFactory":
        """
        Build a factory from a dict of sigma_value -> MultiModalGuiderParams.
        Keys are sorted descending and used for bin lookup in params(sigma).
        """
        if not sigma_to_params:
            raise ValueError("sigma_to_params must be non-empty")
        sorted_items = tuple(sorted(sigma_to_params.items(), key=lambda x: x[0], reverse=True))
        return cls(negative_context=negative_context, _params_by_sigma=sorted_items)

    def params(self, sigma: float | torch.Tensor) -> MultiModalGuiderParams:
        """Return params effective for the given sigma (getter; single source of truth)."""
        sigma_val = float(sigma.item() if isinstance(sigma, torch.Tensor) else sigma)
        return _params_for_sigma_from_sorted_dict(sigma_val, self._params_by_sigma)

    def build_from_sigma(self, sigma: float | torch.Tensor) -> MultiModalGuider:
        """Return a MultiModalGuider with params effective for the given sigma."""
        return MultiModalGuider(
            params=self.params(sigma),
            negative_context=self.negative_context,
        )


def create_multimodal_guider_factory(
    params: MultiModalGuiderParams | MultiModalGuiderFactory,
    negative_context: torch.Tensor | None = None,
) -> MultiModalGuiderFactory:
    """
    Create or return a MultiModalGuiderFactory. Pass constant params for a
    single-params factory (uses MultiModalGuiderFactory.constant), or an existing
    MultiModalGuiderFactory. When given a factory, returns it as-is unless
    negative_context is provided. For sigma-dependent params use
    MultiModalGuiderFactory.from_dict(...) and pass that as params.
    """
    if isinstance(params, MultiModalGuiderFactory):
        if negative_context is not None and params.negative_context is not negative_context:
            return MultiModalGuiderFactory.from_dict(dict(params._params_by_sigma), negative_context=negative_context)
        return params
    return MultiModalGuiderFactory.constant(params, negative_context=negative_context)


def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
    batch_size = to_project.shape[0]
    positive_flat = to_project.reshape(batch_size, -1)
    negative_flat = project_onto.reshape(batch_size, -1)
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
    return dot_product / squared_norm
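The descending-key bin lookup in `_params_for_sigma_from_sorted_dict` is easiest to verify with concrete numbers. A usage sketch (the scale values are illustrative only):

# Two bins: sigma in (0.5, 1.0] uses strong CFG, sigma <= 0.5 uses weak CFG.
factory = MultiModalGuiderFactory.from_dict({
    1.0: MultiModalGuiderParams(cfg_scale=7.0),
    0.5: MultiModalGuiderParams(cfg_scale=2.0),
})

assert factory.params(0.9).cfg_scale == 7.0  # keys >= 0.9 -> [1.0]; bin upper bound 1.0
assert factory.params(0.5).cfg_scale == 2.0  # keys >= 0.5 -> [1.0, 0.5]; smallest is 0.5
assert factory.params(0.2).cfg_scale == 2.0  # keys >= 0.2 -> [1.0, 0.5]; smallest is 0.5
assert factory.params(1.3).cfg_scale == 7.0  # above all keys -> first bin (max key) applies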
ltx2/ltx_core/components/noisers.py
ADDED
@@ -0,0 +1,35 @@
from dataclasses import replace
from typing import Protocol

import torch

from ltx_core.types import LatentState


class Noiser(Protocol):
    """Protocol for adding noise to a latent state during diffusion."""

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...


class GaussianNoiser(Noiser):
    """Adds Gaussian noise to a latent state, scaled by the denoise mask."""

    def __init__(self, generator: torch.Generator):
        super().__init__()

        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        noise = torch.randn(
            *latent_state.latent.shape,
            device=latent_state.latent.device,
            dtype=latent_state.latent.dtype,
            generator=self.generator,
        )
        scaled_mask = latent_state.denoise_mask * noise_scale
        latent = noise * scaled_mask + latent_state.latent * (1 - scaled_mask)
        return replace(
            latent_state,
            latent=latent.to(latent_state.latent.dtype),
        )
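The mixing rule in `GaussianNoiser` is a per-element linear interpolation: where `denoise_mask * noise_scale` is 1 the latent is replaced by fresh noise, where it is 0 the latent is kept unchanged, and fractional values blend the two. A standalone numeric check of that formula:

import torch

latent = torch.zeros(4)
noise = torch.ones(4)
mask = torch.tensor([1.0, 0.5, 0.0, 1.0])  # per-element denoise mask

scaled_mask = mask * 1.0  # noise_scale = 1.0
mixed = noise * scaled_mask + latent * (1 - scaled_mask)
assert torch.allclose(mixed, torch.tensor([1.0, 0.5, 0.0, 1.0]))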
ltx2/ltx_core/components/patchifiers.py
ADDED
@@ -0,0 +1,348 @@
import math
from typing import Optional, Tuple

import einops
import torch

from ltx_core.components.protocols import Patchifier
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape


class VideoLatentPatchifier(Patchifier):
    def __init__(self, patch_size: int):
        # Patch sizes for video latents.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.
        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates the (frame/time, height, width) dimensions
        - axis 3 (size 2) stores `[start, end)` indices within each dimension
        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width).
        # torch.arange creates the starting coordinates for each patch;
        # indexing='ij' keeps the dimensions in (frame, height, width) order.
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension.
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords


def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
    each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
    Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
    Args:
        latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with
            integer scale factors applied per axis.
        causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
            that treat frame zero differently still yield non-negative timestamps.
    """
    # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
    broadcast_shape = [1] * latent_coords.ndim
    broadcast_shape[1] = -1  # axis dimension corresponds to (frame/time, height, width)
    scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)

    # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
    pixel_coords = latent_coords * scale_tensor

    if causal_fix:
        # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
        # Shift and clamp to keep the first-frame timestamps causal and non-negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords


class AudioPatchifier(Patchifier):
    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.
        Args:
            patch_size: Number of mel bins combined into a single patch. This
                controls the resolution along the frequency axis.
            sample_rate: Original waveform sampling rate. Used to map latent
                indices back to seconds so downstream consumers can align audio
                and video cues.
            hop_length: Window hop length used for the spectrogram. Determines
                how many real-time samples separate two consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames and
                latent frames; compensates for additional downsampling inside the
                VAE encoder.
            is_causal: When True, timing is shifted to account for causal
                receptive fields so timestamps do not peek into the future.
            shift: Integer offset applied to the latent indices. Enables
                constructing overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Converts latent indices into real-time seconds while honoring causal
        offsets and the configured hop length.
        Args:
            start_latent: Inclusive start index inside the latent sequence. This
                sets the first timestamp returned.
            end_latent: Exclusive end index. Determines how many timestamps get
                generated.
            dtype: Floating-point dtype used for the returned tensor, allowing
                callers to control precision.
            device: Target device for the timestamp tensor. When omitted the
                computation occurs on CPU to avoid surprising GPU allocations.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment. The "+1" ensures the timestamp
            # corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.
        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside.
        """
        resolved_device = device
        if resolved_device is None:
            resolved_device = torch.device("cpu")

        start_timings = self._get_audio_latent_time_in_sec(
            self.shift,
            num_steps + self.shift,
            torch.float32,
            resolved_device,
        )
        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        end_timings = self._get_audio_latent_time_in_sec(
            self.shift + 1,
            num_steps + self.shift + 1,
            torch.float32,
            resolved_device,
        )
        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
        to derive timestamps for each latent frame based on the configured hop
        length and downsampling.
        Args:
            audio_latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
            corresponding timing metadata when needed.
        """
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute the timestamps that describe each
        frame's position in real time.
        Args:
            audio_latents: Latent tensor to unpatchify.
            output_shape: Shape of the unpatched output tensor.
        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
            metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, freq * channels)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`. For audio this corresponds to timestamps in
        seconds aligned with the original spectrogram grid.
        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch
        Args:
            output_shape: Audio grid specification describing the number of time steps.
            device: Target device for the returned tensor.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
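A shape-level round trip makes the two rearrange patterns concrete. With a spatial patch size of 2 and a 3-frame, 8x8, 4-channel latent, patchify yields 3*4*4 = 48 tokens of dimension 4*1*2*2 = 16 (the sizes here are hypothetical; the patterns are the ones used above):

import einops
import torch

latents = torch.randn(1, 4, 3, 8, 8)  # (batch, channels, frames, height, width)

# patchify pattern with p1=1 (temporal), p2=p3=2 (spatial)
tokens = einops.rearrange(
    latents, "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", p1=1, p2=2, p3=2
)
assert tokens.shape == (1, 48, 16)

# unpatchify pattern with the patch-grid sizes f=3, h=4, w=4
restored = einops.rearrange(tokens, "b (f h w) (c p q) -> b c f (h p) (w q)", f=3, h=4, w=4, p=2, q=2)
assert torch.equal(restored, latents)  # exact inverse: a pure permutation of elements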
ltx2/ltx_core/components/protocols.py
ADDED
@@ -0,0 +1,101 @@
from typing import Protocol, Tuple

import torch

from ltx_core.types import AudioLatentShape, VideoLatentShape


class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.
        Args:
            latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.
        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor, either an AudioLatentShape or a VideoLatentShape.
        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Returns the patch size as a tuple of (temporal, height, width) dimensions."""
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.
        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...


class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers that provide a sigmas schedule tensor for a
    given number of steps. Device is CPU.
    """

    def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ...


class GuiderProtocol(Protocol):
    """
    Protocol for guiders that compute a delta tensor given conditioning inputs.
    The returned delta should be added to the conditional output (cond), enabling
    multiple guiders to be chained together by accumulating their deltas.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ...

    def enabled(self) -> bool:
        """
        Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should
        return False if the scale is 1.0.
        """
        ...


class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps that produce the next sample tensor from the current sample
    tensor, the current denoised sample tensor, and the sigmas tensor.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **kwargs
    ) -> torch.Tensor: ...
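Since these are structural `typing.Protocol` classes, any object with matching methods satisfies them without subclassing. A minimal hypothetical scheduler that conforms to `SchedulerProtocol`:

import torch

class UniformScheduler:
    """Toy scheduler: evenly spaced sigmas from 1.0 down to 0.0 (illustration only)."""

    def execute(self, steps: int, **kwargs) -> torch.FloatTensor:
        return torch.linspace(1.0, 0.0, steps + 1)

scheduler: SchedulerProtocol = UniformScheduler()  # structural match, no inheritance needed
sigmas = scheduler.execute(8)
assert len(sigmas) == 9 and sigmas[0] == 1.0 and sigmas[-1] == 0.0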
ltx2/ltx_core/components/schedulers.py
ADDED
@@ -0,0 +1,130 @@
import math
from functools import lru_cache

import numpy
import scipy.stats
import torch

from ltx_core.components.protocols import SchedulerProtocol

BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096


class LTX2Scheduler(SchedulerProtocol):
    """
    Default scheduler for LTX-2 diffusion sampling.
    Generates a sigma schedule with token-count-dependent shifting and optional
    stretching to a terminal value.
    """

    def execute(
        self,
        steps: int,
        latent: torch.Tensor | None = None,
        max_shift: float = 2.05,
        base_shift: float = 0.95,
        stretch: bool = True,
        terminal: float = 0.1,
        default_number_of_tokens: int = MAX_SHIFT_ANCHOR,
        **_kwargs,
    ) -> torch.FloatTensor:
        tokens = math.prod(latent.shape[2:]) if latent is not None else default_number_of_tokens
        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Linearly interpolate the shift between the two token-count anchors.
        x1 = BASE_SHIFT_ANCHOR
        x2 = MAX_SHIFT_ANCHOR
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = tokens * mm + b

        power = 1
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that the final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return sigmas.to(torch.float32)


class LinearQuadraticScheduler(SchedulerProtocol):
    """
    Scheduler with linear steps followed by quadratic steps.
    Produces a sigma schedule that transitions linearly up to a threshold,
    then follows a quadratic curve for the remaining steps.
    """

    def execute(
        self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
    ) -> torch.FloatTensor:
        if steps == 1:
            return torch.FloatTensor([1.0, 0.0])

        if linear_steps is None:
            linear_steps = steps // 2
        linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
        threshold_noise_step_diff = linear_steps - threshold_noise * steps
        quadratic_steps = steps - linear_steps
        quadratic_sigma_schedule = []
        if quadratic_steps > 0:
            quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
            linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
            const = quadratic_coef * (linear_steps**2)
            quadratic_sigma_schedule = [
                quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, steps)
            ]
        sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
        sigma_schedule = [1.0 - x for x in sigma_schedule]
        return torch.FloatTensor(sigma_schedule)


class BetaScheduler(SchedulerProtocol):
    """
    Scheduler using a beta distribution to sample timesteps.
    Based on: https://arxiv.org/abs/2407.12173
    """

    shift = 2.37
    timesteps_length = 10000

    def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor:
        """
        Execute the beta scheduler.
        Args:
            steps: The number of steps to execute the scheduler for.
            alpha: The alpha parameter for the beta distribution.
            beta: The beta parameter for the beta distribution.
        Warnings:
            The number of entries in `sigmas` may theoretically be fewer than `steps + 1`,
            because identical timesteps are deduplicated.
        Returns:
            A tensor of sigmas.
        """
        model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
        total_timesteps = len(model_sampling_sigmas) - 1
        ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
        ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist()
        ts = list(dict.fromkeys(ts))  # deduplicate while preserving order

        sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0]
        return torch.FloatTensor(sigmas)


@lru_cache(maxsize=5)
def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
    timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
    return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps])


def flux_time_shift(mu: float, sigma: float, t: float) -> float:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
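The token-dependent shift in `LTX2Scheduler.execute` is a straight line through two anchor points: `base_shift` at 1024 tokens and `max_shift` at 4096. A quick numeric check of that interpolation, using the default values above:

base_shift, max_shift = 0.95, 2.05
x1, x2 = 1024, 4096  # BASE_SHIFT_ANCHOR, MAX_SHIFT_ANCHOR

mm = (max_shift - base_shift) / (x2 - x1)  # slope of the line between the anchors
b = base_shift - mm * x1                   # intercept

assert abs((1024 * mm + b) - 0.95) < 1e-9  # at the base anchor the shift equals base_shift
assert abs((4096 * mm + b) - 2.05) < 1e-9  # at the max anchor the shift equals max_shift
assert abs((2560 * mm + b) - 1.50) < 1e-9  # the midpoint lands halfway between the two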
ltx2/ltx_core/conditioning/__init__.py
ADDED
@@ -0,0 +1,19 @@
"""Conditioning utilities: latent state, tools, and conditioning types."""

from ltx_core.conditioning.exceptions import ConditioningError
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.conditioning.types import (
    ConditioningItemAttentionStrengthWrapper,
    VideoConditionByKeyframeIndex,
    VideoConditionByLatentIndex,
    VideoConditionByReferenceLatent,
)

__all__ = [
    "ConditioningError",
    "ConditioningItem",
    "ConditioningItemAttentionStrengthWrapper",
    "VideoConditionByKeyframeIndex",
    "VideoConditionByLatentIndex",
    "VideoConditionByReferenceLatent",
]
ltx2/ltx_core/conditioning/exceptions.py
ADDED
@@ -0,0 +1,4 @@
class ConditioningError(Exception):
    """
    Class for conditioning-related errors.
    """
ltx2/ltx_core/conditioning/item.py
ADDED
@@ -0,0 +1,20 @@
from typing import Protocol

from ltx_core.tools import LatentTools
from ltx_core.types import LatentState


class ConditioningItem(Protocol):
    """Protocol for conditioning items that modify latent state during diffusion."""

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """
        Apply the conditioning to the latent state.

        Args:
            latent_state: The latent state to apply the conditioning to. This state is always patchified.

        Returns:
            The latent state after the conditioning has been applied.

        IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should append them at
        the end of the latent.
        """
        ...
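For illustration, the smallest class satisfying this protocol structurally is a
no-op (hypothetical example, not part of the vendored package):

class IdentityConditioning:
    """Satisfies ConditioningItem: returns the latent state unchanged."""

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        # Appends no tokens, so there is nothing to add at the end of the latent.
        return latent_state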
ltx2/ltx_core/conditioning/mask_utils.py
ADDED
@@ -0,0 +1,210 @@
"""Utilities for building 2D self-attention masks for conditioning items."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from ltx_core.types import LatentState


def resolve_cross_mask(
    attention_mask: float | int | torch.Tensor,
    num_new_tokens: int,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Convert an attention_mask (scalar or tensor) to a (B, M) cross_mask tensor.

    Args:
        attention_mask: Scalar value applied uniformly, 1D tensor of shape (M,)
            broadcast across batch, or 2D tensor of shape (B, M).
        num_new_tokens: Number of new conditioning tokens M.
        batch_size: Batch size B.
        device: Device for the output tensor.
        dtype: Data type for the output tensor.

    Returns:
        Cross-mask tensor of shape (B, M).
    """
    if isinstance(attention_mask, (int, float)):
        return torch.full(
            (batch_size, num_new_tokens),
            fill_value=float(attention_mask),
            device=device,
            dtype=dtype,
        )
    mask = attention_mask.to(device=device, dtype=dtype)

    # Handle scalar (0-D) tensor like a Python scalar.
    if mask.dim() == 0:
        return torch.full(
            (batch_size, num_new_tokens),
            fill_value=float(mask.item()),
            device=device,
            dtype=dtype,
        )

    if mask.dim() == 1:
        if mask.shape[0] != num_new_tokens:
            raise ValueError(
                f"1-D attention_mask length must equal num_new_tokens ({num_new_tokens}), got shape {tuple(mask.shape)}"
            )
        mask = mask.unsqueeze(0).expand(batch_size, -1)
    elif mask.dim() == 2:
        b, m = mask.shape
        if m != num_new_tokens:
            raise ValueError(
                f"2-D attention_mask second dimension must equal num_new_tokens ({num_new_tokens}), "
                f"got shape {tuple(mask.shape)}"
            )
        if b not in (batch_size, 1):
            raise ValueError(
                f"2-D attention_mask batch dimension must equal batch_size ({batch_size}) or 1, "
                f"got shape {tuple(mask.shape)}"
            )
        if b == 1 and batch_size > 1:
            mask = mask.expand(batch_size, -1)
    else:
        raise ValueError(
            f"attention_mask tensor must be 0-D, 1-D, or 2-D, got {mask.dim()}-D with shape {tuple(mask.shape)}"
        )
    return mask


def update_attention_mask(
    latent_state: LatentState,
    attention_mask: float | torch.Tensor | None,
    num_noisy_tokens: int,
    num_new_tokens: int,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor | None:
    """Build or update the self-attention mask for newly appended conditioning tokens.

    If *attention_mask* is ``None`` and no existing mask is present, returns
    ``None``. If *attention_mask* is ``None`` but an existing mask is present,
    the mask is expanded with full attention (1s) for the new tokens so that
    its dimensions stay consistent with the growing latent sequence. Otherwise,
    resolves *attention_mask* to a per-token cross-mask and expands the 2-D
    attention mask via :func:`build_attention_mask`.

    Args:
        latent_state: Current latent state (provides the existing mask and total
            existing-token count).
        attention_mask: Per-token attention weight. Scalar, 1-D ``(M,)``, 2-D
            ``(B, M)`` tensor, or ``None`` (no-op).
        num_noisy_tokens: Number of original noisy tokens (from
            ``latent_tools.target_shape.token_count()``).
        num_new_tokens: Number of new conditioning tokens being appended.
        batch_size: Batch size.
        device: Device for the output tensor.
        dtype: Data type for the output tensor.

    Returns:
        Updated attention mask of shape ``(B, N+M, N+M)``, or ``None`` if no
        masking is needed.
    """
    if attention_mask is None:
        if latent_state.attention_mask is None:
            return None
        # Existing mask present but no new mask requested: pad with 1s (full
        # attention) so the mask dimensions stay consistent with the growing
        # latent sequence.
        cross_mask = torch.ones(batch_size, num_new_tokens, device=device, dtype=dtype)
        return build_attention_mask(
            existing_mask=latent_state.attention_mask,
            num_noisy_tokens=num_noisy_tokens,
            num_new_tokens=num_new_tokens,
            num_existing_tokens=latent_state.latent.shape[1],
            cross_mask=cross_mask,
            device=device,
            dtype=dtype,
        )

    cross_mask = resolve_cross_mask(attention_mask, num_new_tokens, batch_size, device, dtype)
    return build_attention_mask(
        existing_mask=latent_state.attention_mask,
        num_noisy_tokens=num_noisy_tokens,
        num_new_tokens=num_new_tokens,
        num_existing_tokens=latent_state.latent.shape[1],
        cross_mask=cross_mask,
        device=device,
        dtype=dtype,
    )


def build_attention_mask(
    existing_mask: torch.Tensor | None,
    num_noisy_tokens: int,
    num_new_tokens: int,
    num_existing_tokens: int,
    cross_mask: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Expand the attention mask to include newly appended conditioning tokens.

    Each conditioning item appends M new reference tokens to the sequence. This function
    builds a (B, N+M, N+M) attention mask with the following block structure:

                        noisy        prev_ref      new_ref
                      (N_noisy)    (N-N_noisy)       (M)
                    ┌───────────┬───────────┬───────────┐
        noisy       │           │           │           │
        (N_noisy)   │  existing │  existing │   cross   │
                    │           │           │           │
                    ├───────────┼───────────┼───────────┤
        prev_ref    │           │           │           │
        (N-N_noisy) │  existing │  existing │     0     │
                    │           │           │           │
                    ├───────────┼───────────┼───────────┤
        new_ref     │           │           │           │
        (M)         │   cross   │     0     │     1     │
                    │           │           │           │
                    └───────────┴───────────┴───────────┘

    Where:
        - **existing**: preserved from the previous mask (or 1.0 if first conditioning)
        - **cross**: values from *cross_mask* (shape B, M), in [0, 1]
        - **0**: no attention between different reference groups

    Args:
        existing_mask: Current attention mask of shape (B, N, N), or None if no mask exists yet.
            When None, the top-left NxN block is filled with 1s (full attention between all
            existing tokens including any prior reference tokens that had no mask).
        num_noisy_tokens: Number of original noisy tokens (always at positions [0:num_noisy_tokens]).
        num_new_tokens: Number of new conditioning tokens M being appended.
        num_existing_tokens: Total number of current tokens N (noisy + any prior conditioning tokens).
        cross_mask: Per-token attention weight of shape (B, M) controlling attention between
            new reference tokens and noisy tokens. Values in [0, 1].
        device: Device for the output tensor.
        dtype: Data type for the output tensor.

    Returns:
        Attention mask of shape (B, N+M, N+M) with values in [0, 1].
    """
    batch_size = cross_mask.shape[0]
    total = num_existing_tokens + num_new_tokens

    # Start with zeros
    mask = torch.zeros((batch_size, total, total), device=device, dtype=dtype)

    # Top-left: preserve existing mask or fill with 1s for noisy tokens
    if existing_mask is not None:
        mask[:, :num_existing_tokens, :num_existing_tokens] = existing_mask
    else:
        mask[:, :num_existing_tokens, :num_existing_tokens] = 1.0

    # Bottom-right: new reference tokens fully attend to themselves
    mask[:, num_existing_tokens:, num_existing_tokens:] = 1.0

    # Cross-attention between noisy tokens and new reference tokens
    # cross_mask shape: (B, M) -> broadcast to (B, N_noisy, M) and (B, M, N_noisy)

    # Noisy tokens attending to new reference tokens: [0:N_noisy, N:N+M]
    # Each column j in this block gets cross_mask[:, j]
    mask[:, :num_noisy_tokens, num_existing_tokens:] = cross_mask.unsqueeze(1)

    # New reference tokens attending to noisy tokens: [N:N+M, 0:N_noisy]
    # Each row i in this block gets cross_mask[:, i]
    mask[:, num_existing_tokens:, :num_noisy_tokens] = cross_mask.unsqueeze(2)

    # [N_noisy:N, N:N+M] and [N:N+M, N_noisy:N] remain 0 (no cross-ref attention)

    return mask
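A quick sketch checking the block structure on toy shapes (illustrative sizes;
uses only this module's `torch` import):

B, n_noisy, m = 1, 4, 2  # 4 noisy tokens, 2 new reference tokens
cross = torch.full((B, m), 0.5)
mask = build_attention_mask(
    existing_mask=None,
    num_noisy_tokens=n_noisy,
    num_new_tokens=m,
    num_existing_tokens=n_noisy,  # first conditioning: no prior reference tokens
    cross_mask=cross,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
assert mask.shape == (B, n_noisy + m, n_noisy + m)
assert (mask[0, :n_noisy, :n_noisy] == 1.0).all()  # existing block
assert (mask[0, :n_noisy, n_noisy:] == 0.5).all()  # noisy -> new_ref cross block
assert (mask[0, n_noisy:, n_noisy:] == 1.0).all()  # new_ref self block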
ltx2/ltx_core/conditioning/types/__init__.py
ADDED
@@ -0,0 +1,13 @@
"""Conditioning type implementations."""

from ltx_core.conditioning.types.attention_strength_wrapper import ConditioningItemAttentionStrengthWrapper
from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex
from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex
from ltx_core.conditioning.types.reference_video_cond import VideoConditionByReferenceLatent

__all__ = [
    "ConditioningItemAttentionStrengthWrapper",
    "VideoConditionByKeyframeIndex",
    "VideoConditionByLatentIndex",
    "VideoConditionByReferenceLatent",
]
ltx2/ltx_core/conditioning/types/attention_strength_wrapper.py
ADDED
@@ -0,0 +1,71 @@
"""Wrapper conditioning item that adds attention masking to any inner conditioning."""

from dataclasses import replace

import torch

from ltx_core.conditioning.item import ConditioningItem
from ltx_core.conditioning.mask_utils import update_attention_mask
from ltx_core.tools import LatentTools
from ltx_core.types import LatentState


class ConditioningItemAttentionStrengthWrapper(ConditioningItem):
    """Wraps a conditioning item to add an attention mask for its tokens.

    Separates the *attention-masking* concern from the underlying conditioning
    logic (token layout, positional encoding, denoise strength). The inner
    conditioning item appends tokens to the latent sequence as usual, and this
    wrapper then builds or updates the self-attention mask so that the newly
    added tokens interact with the noisy tokens according to *attention_mask*.

    Args:
        conditioning: Any conditioning item that appends tokens to the latent.
        attention_mask: Per-token attention weight controlling how strongly the
            new conditioning tokens attend to/from noisy tokens. Can be a
            scalar (float) applied uniformly, or a tensor of shape ``(B, M)``
            for spatial control, where ``M = F * H * W`` is the number of
            patchified conditioning tokens. Values in ``[0, 1]``.

    Example::

        cond = ConditioningItemAttentionStrengthWrapper(
            VideoConditionByReferenceLatent(latent=ref, strength=1.0),
            attention_mask=0.5,
        )
        state = cond.apply_to(latent_state, latent_tools)
    """

    def __init__(
        self,
        conditioning: ConditioningItem,
        attention_mask: float | torch.Tensor,
    ):
        self.conditioning = conditioning
        self.attention_mask = attention_mask

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: LatentTools,
    ) -> LatentState:
        """Apply inner conditioning, then build the attention mask for its tokens."""
        # Snapshot the original state for mask building
        original_state = latent_state

        # Inner conditioning appends tokens (positions, denoise mask, etc.)
        new_state = self.conditioning.apply_to(latent_state, latent_tools)

        num_new_tokens = new_state.latent.shape[1] - original_state.latent.shape[1]
        if num_new_tokens == 0:
            return new_state

        # Build the attention mask using the *original* state as the reference
        # so that the block structure is computed correctly.
        new_attention_mask = update_attention_mask(
            latent_state=original_state,
            attention_mask=self.attention_mask,
            num_noisy_tokens=latent_tools.target_shape.token_count(),
            num_new_tokens=num_new_tokens,
            batch_size=new_state.latent.shape[0],
            device=new_state.latent.device,
            dtype=new_state.latent.dtype,
        )

        return replace(new_state, attention_mask=new_attention_mask)
ltx2/ltx_core/conditioning/types/keyframe_cond.py
ADDED
@@ -0,0 +1,70 @@
import torch

from ltx_core.components.patchifiers import get_pixel_coords
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.conditioning.mask_utils import update_attention_mask
from ltx_core.tools import VideoLatentTools
from ltx_core.types import LatentState, VideoLatentShape


class VideoConditionByKeyframeIndex(ConditioningItem):
    """
    Conditions video generation on keyframe latents at a specific frame index.

    Appends keyframe tokens to the latent state with positions offset by frame_idx,
    and sets denoise strength according to the strength parameter.

    To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.

    Args:
        keyframes: Keyframe latents [B, C, F, H, W].
        frame_idx: Frame index offset for positional encoding.
        strength: Conditioning strength (1.0 = clean, 0.0 = fully denoised).
    """

    def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
        self.keyframes = keyframes
        self.frame_idx = frame_idx
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: VideoLatentTools,
    ) -> LatentState:
        tokens = latent_tools.patchifier.patchify(self.keyframes)
        latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
            device=self.keyframes.device,
        )
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=latent_tools.scale_factors,
            causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
        )

        positions[:, 0, ...] += self.frame_idx
        positions = positions.to(dtype=torch.float32)
        positions[:, 0, ...] /= latent_tools.fps

        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.keyframes.device,
            dtype=self.keyframes.dtype,
        )

        new_attention_mask = update_attention_mask(
            latent_state=latent_state,
            attention_mask=None,
            num_noisy_tokens=latent_tools.target_shape.token_count(),
            num_new_tokens=tokens.shape[1],
            batch_size=tokens.shape[0],
            device=self.keyframes.device,
            dtype=self.keyframes.dtype,
        )

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=new_attention_mask,
        )
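A usage sketch (hypothetical tensors): first-frame conditioning in an
image-to-video setup.

# `first_frame_latent` is a [B, C, 1, H, W] latent encoding the start image.
cond = VideoConditionByKeyframeIndex(
    keyframes=first_frame_latent,
    frame_idx=0,   # anchor at frame 0; the causal fix applies only here
    strength=1.0,  # keep the keyframe fully clean while denoising
)
latent_state = cond.apply_to(latent_state, latent_tools)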
ltx2/ltx_core/conditioning/types/latent_cond.py
ADDED
@@ -0,0 +1,44 @@
import torch

from ltx_core.conditioning.exceptions import ConditioningError
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.tools import LatentTools
from ltx_core.types import LatentState


class VideoConditionByLatentIndex(ConditioningItem):
    """
    Conditions video generation by injecting latents at a specific latent frame index.

    Replaces tokens in the latent state at positions corresponding to latent_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int):
        self.latent = latent
        self.strength = strength
        self.latent_idx = latent_idx

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape
        tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape()

        if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width):
            raise ConditioningError(
                f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected "
                f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure "
                "the image and latent have the same spatial shape."
            )

        tokens = latent_tools.patchifier.patchify(self.latent)
        start_token = latent_tools.patchifier.get_token_count(
            latent_tools.target_shape._replace(frames=self.latent_idx)
        )
        stop_token = start_token + tokens.shape[1]

        latent_state = latent_state.clone()

        latent_state.latent[:, start_token:stop_token] = tokens
        latent_state.clean_latent[:, start_token:stop_token] = tokens
        latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength

        return latent_state
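The replaced token range follows directly from the patchifier arithmetic, e.g.
(illustrative numbers, assuming one token per latent cell per frame):

# For a target of F=8 latent frames on a 16 x 16 latent grid,
# get_token_count(target_shape._replace(frames=latent_idx)) counts every token
# before latent_idx. With latent_idx=2 that is 2 * 16 * 16 = 512, so a
# single-frame conditioning latent overwrites tokens [512:768].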
ltx2/ltx_core/conditioning/types/noise_mask_cond.py
ADDED
@@ -0,0 +1,45 @@
from dataclasses import dataclass

from ltx_core.components.patchifiers import get_pixel_coords
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.tools import LatentTools, SpatioTemporalScaleFactors
from ltx_core.types import AudioLatentShape, LatentState, VideoLatentShape


@dataclass(frozen=True)
class TemporalRegionMask(ConditioningItem):
    """Conditioning item that sets ``denoise_mask = 0`` outside a time range
    and ``1`` inside, so only the specified temporal region is regenerated.

    Uses ``start_time`` and ``end_time`` in seconds. Works in *patchified*
    (token) space using the patchifier's ``get_patch_grid_bounds``: for video,
    coords are latent frame indices converted to seconds via ``fps``; for
    audio, coords are already in seconds.
    """

    start_time: float  # seconds, inclusive
    end_time: float  # seconds, exclusive
    fps: float

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        coords = latent_tools.patchifier.get_patch_grid_bounds(
            latent_tools.target_shape, device=latent_state.denoise_mask.device
        )
        if isinstance(latent_tools.target_shape, AudioLatentShape):
            # Audio: get_patch_grid_bounds returns bounds in seconds.
            t_boundaries = coords[:, 0]
        elif isinstance(latent_tools.target_shape, VideoLatentShape):
            # Video: get_patch_grid_bounds returns latent bounds; convert to frame numbers / pixel bounds.
            scale_factors = getattr(latent_tools, "scale_factors", SpatioTemporalScaleFactors.default())
            pixel_bounds = get_pixel_coords(coords, scale_factors, causal_fix=getattr(latent_tools, "causal_fix", True))
            # Convert frame numbers to seconds.
            t_boundaries = pixel_bounds[:, 0] / self.fps
        else:
            raise ValueError("Unsupported LatentShape type, expected AudioLatentShape or VideoLatentShape")
        t_start, t_end = t_boundaries.unbind(dim=-1)  # [B, N]
        in_region = (t_end > self.start_time) & (t_start < self.end_time)
        state = latent_state.clone()
        mask_val = in_region.to(state.denoise_mask.dtype)
        if state.denoise_mask.dim() == 3:
            mask_val = mask_val.unsqueeze(-1)
        state.denoise_mask.copy_(mask_val)
        return state
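A usage sketch (hypothetical values): regenerate only seconds 2-4 of an
existing clip and freeze everything else.

mask_cond = TemporalRegionMask(start_time=2.0, end_time=4.0, fps=24.0)
latent_state = mask_cond.apply_to(latent_state, latent_tools)
# denoise_mask is now 1 for tokens overlapping [2.0, 4.0) seconds and 0
# elsewhere, so the sampler leaves tokens outside the window untouched.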
ltx2/ltx_core/conditioning/types/reference_video_cond.py
ADDED
@@ -0,0 +1,91 @@
"""Reference video conditioning for IC-LoRA inference."""

import torch

from ltx_core.components.patchifiers import get_pixel_coords
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.conditioning.mask_utils import update_attention_mask
from ltx_core.tools import VideoLatentTools
from ltx_core.types import LatentState, VideoLatentShape


class VideoConditionByReferenceLatent(ConditioningItem):
    """
    Conditions video generation on a reference video latent for IC-LoRA inference.

    IC-LoRAs are trained by concatenating reference (control signal) and target tokens,
    learning to attend across both. This class replicates that setup at inference by
    appending reference tokens to the latent sequence.

    IC-LoRAs can be trained with lower-resolution references than the target (e.g., 384px
    reference for 768px output) for efficiency and better generalization. The
    `downscale_factor` scales reference positions to match target coordinates, preserving
    the learned positional relationships. This must match the factor used during training
    (stored in LoRA metadata).

    To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.

    Args:
        latent: Reference video latents [B, C, F, H, W]
        downscale_factor: Target/reference resolution ratio (e.g., 2 = half-resolution
            reference). Spatial positions are scaled by this factor.
        strength: Conditioning strength. 1.0 = full (reference kept clean),
            0.0 = none (reference denoised). Default 1.0.
    """

    def __init__(
        self,
        latent: torch.Tensor,
        downscale_factor: int = 1,
        strength: float = 1.0,
    ):
        self.latent = latent
        self.downscale_factor = downscale_factor
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: VideoLatentTools,
    ) -> LatentState:
        """Append reference video tokens with scaled positions."""
        tokens = latent_tools.patchifier.patchify(self.latent)

        # Compute positions for the reference video's actual dimensions
        latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape.from_torch_shape(self.latent.shape),
            device=self.latent.device,
        )
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=latent_tools.scale_factors,
            causal_fix=latent_tools.causal_fix,
        )
        positions = positions.to(dtype=torch.float32)
        positions[:, 0, ...] /= latent_tools.fps

        # Scale spatial positions to match target coordinate space
        if self.downscale_factor != 1:
            positions[:, 1, ...] *= self.downscale_factor  # height axis
            positions[:, 2, ...] *= self.downscale_factor  # width axis

        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.latent.device,
            dtype=self.latent.dtype,
        )

        new_attention_mask = update_attention_mask(
            latent_state=latent_state,
            attention_mask=None,
            num_noisy_tokens=latent_tools.target_shape.token_count(),
            num_new_tokens=tokens.shape[1],
            batch_size=tokens.shape[0],
            device=self.latent.device,
            dtype=self.latent.dtype,
        )

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=new_attention_mask,
        )
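A usage sketch (hypothetical tensors): conditioning on a half-resolution
reference, matching an IC-LoRA trained with downscale_factor=2.

ref_cond = VideoConditionByReferenceLatent(
    latent=reference_latent,  # [B, C, F, H/2, W/2] relative to the target
    downscale_factor=2,       # must match the factor stored in the LoRA metadata
    strength=1.0,             # keep the reference clean during denoising
)
latent_state = ref_cond.apply_to(latent_state, latent_tools)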
ltx2/ltx_core/guidance/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""Guidance and perturbation utilities for attention manipulation."""

from ltx_core.guidance.perturbations import (
    BatchedPerturbationConfig,
    Perturbation,
    PerturbationConfig,
    PerturbationType,
)

__all__ = [
    "BatchedPerturbationConfig",
    "Perturbation",
    "PerturbationConfig",
    "PerturbationType",
]
ltx2/ltx_core/guidance/perturbations.py
ADDED
@@ -0,0 +1,79 @@
from dataclasses import dataclass
from enum import Enum

import torch
from torch._prims_common import DeviceLikeType


class PerturbationType(Enum):
    """Types of attention perturbations for STG (Spatio-Temporal Guidance)."""

    SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
    SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
    SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
    SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"


@dataclass(frozen=True)
class Perturbation:
    """A single perturbation specifying which attention type to skip and in which blocks."""

    type: PerturbationType
    blocks: list[int] | None  # None means all blocks

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        if self.type != perturbation_type:
            return False

        if self.blocks is None:
            return True

        return block in self.blocks


@dataclass(frozen=True)
class PerturbationConfig:
    """Configuration holding a list of perturbations for a single sample."""

    perturbations: list[Perturbation] | None

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        if self.perturbations is None:
            return False

        return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)

    @staticmethod
    def empty() -> "PerturbationConfig":
        return PerturbationConfig([])


@dataclass(frozen=True)
class BatchedPerturbationConfig:
    """Perturbation configurations for a batch, with utilities for generating attention masks."""

    perturbations: list[PerturbationConfig]

    def mask(
        self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
    ) -> torch.Tensor:
        mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
        for batch_idx, perturbation in enumerate(self.perturbations):
            if perturbation.is_perturbed(perturbation_type, block):
                mask[batch_idx] = 0

        return mask

    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
        mask = self.mask(perturbation_type, block, values.device, values.dtype)
        return mask.view(mask.numel(), *([1] * len(values.shape[1:])))

    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)

    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)

    @staticmethod
    def empty(batch_size: int) -> "BatchedPerturbationConfig":
        return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
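A sketch of how a guider can gate attention with these configs (illustrative
values; uses this module's imports):

# Sample 0 is unperturbed; sample 1 skips video self-attention in blocks 0-3.
batch_cfg = BatchedPerturbationConfig([
    PerturbationConfig(None),
    PerturbationConfig([Perturbation(PerturbationType.SKIP_VIDEO_SELF_ATTN, blocks=[0, 1, 2, 3])]),
])
gate = batch_cfg.mask(
    PerturbationType.SKIP_VIDEO_SELF_ATTN, block=2, device=torch.device("cpu"), dtype=torch.float32
)
# gate == tensor([1., 0.]); multiplying an attention output by mask_like(...)
# zeroes the perturbed sample's contribution in that block.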
ltx2/ltx_core/layer_streaming.py
ADDED
@@ -0,0 +1,324 @@
"""Layer streaming wrapper for memory-efficient inference.

Keeps most transformer/decoder layers on CPU pinned memory and streams them
to GPU on demand, using a secondary CUDA stream to prefetch upcoming layers
so that data transfer overlaps with compute.

General-purpose: works with any ``nn.Module`` whose forward iterates over a
``nn.ModuleList`` attribute (e.g. ``transformer_blocks``, ``layers``).

Each layer is evicted back to CPU immediately after its forward completes,
and prefetch uses modular indexing so the last layer's prefetch wraps around
to prepare early layers for the next forward pass.

Example
-------
>>> model = build_my_model(device=torch.device("cpu"))
>>> model = LayerStreamingWrapper(
...     model,
...     layers_attr="transformer_blocks",
...     target_device=torch.device("cuda:0"),
...     prefetch_count=2,
... )
>>> out = model(inputs)  # hooks handle layer streaming
>>> model.teardown()  # move everything back to CPU
"""

from __future__ import annotations

import functools
import itertools
import logging
from typing import Any

import torch
from torch import nn

logger = logging.getLogger(__name__)


def _resolve_attr(module: nn.Module, dotted_path: str) -> nn.ModuleList:
    """Resolve a dotted attribute path like ``'model.language_model.layers'``."""
    obj: Any = module
    for part in dotted_path.split("."):
        obj = getattr(obj, part)
    if not isinstance(obj, nn.ModuleList):
        raise TypeError(f"Expected nn.ModuleList at '{dotted_path}', got {type(obj).__name__}")
    return obj


class _LayerStore:
    """Manages on-demand pinning of layer parameters for GPU streaming.

    Stores references to each layer's source data (which may be file-backed
    mmap views or in-memory tensors). When a layer needs to be transferred
    to GPU, its source data is pinned on demand and copied; on eviction the
    pinned copy is freed and the source data is restored.
    """

    def __init__(self, layers: nn.ModuleList, target_device: torch.device) -> None:
        self.target_device = target_device
        self.num_layers = len(layers)
        self._on_gpu: set[int] = set()

        # Keep a reference to the source data for each layer so we can pin it
        # on demand and restore it after eviction.
        self._source_data: list[dict[str, torch.Tensor]] = []
        for layer in layers:
            source: dict[str, torch.Tensor] = {}
            for name, tensor in itertools.chain(layer.named_parameters(), layer.named_buffers()):
                source[name] = tensor.data
            self._source_data.append(source)

        # Hold pinned tensors alive until the H2D transfer completes.
        # Without this, the CachingHostAllocator can reclaim a pinned tensor
        # as soon as its Python reference is dropped, even if an async H2D
        # transfer is still reading from it.
        self._pinned_in_flight: dict[int, list[torch.Tensor]] = {}

    def _check_idx(self, idx: int) -> None:
        if idx < 0 or idx >= self.num_layers:
            raise IndexError(f"Layer index {idx} out of range [0, {self.num_layers})")

    def is_on_gpu(self, idx: int) -> bool:
        return idx in self._on_gpu

    def move_to_gpu(self, idx: int, layer: nn.Module, *, non_blocking: bool = False) -> None:
        """Pin layer *idx* on demand, then transfer to GPU."""
        self._check_idx(idx)
        if idx in self._on_gpu:
            return
        source = self._source_data[idx]
        pinned_refs: list[torch.Tensor] = []
        for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()):
            pinned = source[name].pin_memory()
            param.data = pinned.to(self.target_device, non_blocking=non_blocking)
            pinned_refs.append(pinned)
        # Keep pinned tensors alive until eviction — the async H2D transfer
        # may still be reading from them.
        self._pinned_in_flight[idx] = pinned_refs
        self._on_gpu.add(idx)

    def evict_to_cpu(self, idx: int, layer: nn.Module) -> None:
        """Restore source data, freeing the GPU and pinned copies."""
        self._check_idx(idx)
        if idx not in self._on_gpu:
            return
        source = self._source_data[idx]
        for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()):
            param.data = source[name]
        # Release pinned tensors — the H2D transfer is complete by now
        # (the compute stream waited on the prefetch event before using
        # the layer, and we only evict after compute finishes).
        self._pinned_in_flight.pop(idx, None)
        self._on_gpu.discard(idx)

    def cleanup(self) -> None:
        """Release all source data and in-flight pinned references.

        After this call, the source tensors can be garbage-collected once
        the layer parameters (which still reference them via ``.data``) are
        also released (e.g. via ``.to("meta")``).
        """
        for source_dict in self._source_data:
            source_dict.clear()
        self._source_data.clear()
        self._pinned_in_flight.clear()


class _AsyncPrefetcher:
    """Issues H2D transfers on a dedicated CUDA stream.

    Uses per-layer CUDA events so that the compute stream only waits for the
    specific layer it needs, not all pending transfers.
    """

    def __init__(self, store: _LayerStore, layers: nn.ModuleList) -> None:
        self._store = store
        self._layers = layers
        self._stream = torch.cuda.Stream(device=store.target_device)
        self._events: dict[int, torch.cuda.Event] = {}

    def prefetch(self, idx: int) -> None:
        """Begin async transfer of layer *idx* to GPU (no-op if already there)."""
        if self._store.is_on_gpu(idx) or idx in self._events:
            return
        with torch.cuda.stream(self._stream):
            self._store.move_to_gpu(idx, self._layers[idx], non_blocking=True)
            event = torch.cuda.Event()
            event.record(self._stream)
            self._events[idx] = event

    def wait(self, idx: int) -> None:
        """Block the compute stream until layer *idx* transfer is complete."""
        event = self._events.pop(idx, None)
        if event is not None:
            torch.cuda.current_stream(self._store.target_device).wait_event(event)

    def cleanup(self) -> None:
        """Drain pending work and release CUDA stream/event resources."""
        self._events.clear()
        self._stream = None
        self._layers = None
        self._store = None


class LayerStreamingWrapper(nn.Module):
    """Wraps a model to stream its sequential layers between CPU and GPU.

    Each layer is evicted immediately after its forward completes, and
    prefetch wraps around using modular indexing so the end of one forward
    pass prepares early layers for the next.

    Parameters
    ----------
    model:
        The model to wrap, with all parameters on **CPU**.
    layers_attr:
        Dotted attribute path to the ``nn.ModuleList`` of sequential layers
        (e.g. ``"transformer_blocks"`` or ``"model.language_model.layers"``).
    target_device:
        The GPU device to use for compute.
    prefetch_count:
        How many layers ahead to prefetch. The maximum number of layers on
        GPU at once is ``1 + prefetch_count``. Must be >= 1.
    """

    def __init__(
        self,
        model: nn.Module,
        layers_attr: str,
        target_device: torch.device,
        prefetch_count: int = 2,
    ) -> None:
        if prefetch_count < 1:
            raise ValueError("prefetch_count must be >= 1")
        super().__init__()
        # Store the wrapped model as a submodule so parameters are discoverable.
        self._model = model
        self._layers = _resolve_attr(model, layers_attr)
        self._target_device = target_device
        # Clamp: no point prefetching more than num_layers - 1 (the rest are evicted).
        self._prefetch_count = min(prefetch_count, len(self._layers) - 1)
        self._hooks: list[torch.utils.hooks.RemovableHandle] = []

        self._setup()

    # ------------------------------------------------------------------
    # Setup / teardown
    # ------------------------------------------------------------------

    def _setup(self) -> None:
        # 1. Build the pinned CPU store (copies all layer tensors to pinned memory).
        self._store = _LayerStore(self._layers, self._target_device)

        # 2. Move all NON-layer params/buffers to GPU.
        layer_tensor_ids: set[int] = set()
        for layer in self._layers:
            for t in itertools.chain(layer.parameters(), layer.buffers()):
                layer_tensor_ids.add(id(t))

        for p in self._model.parameters():
            if id(p) not in layer_tensor_ids:
                p.data = p.data.to(self._target_device)
        for b in self._model.buffers():
            if id(b) not in layer_tensor_ids:
                b.data = b.data.to(self._target_device)

        # 3. Pre-load the first (1 + prefetch_count) layers synchronously.
        for idx in range(min(self._prefetch_count + 1, len(self._layers))):
            self._store.move_to_gpu(idx, self._layers[idx])

        # 4. Create the async prefetcher and register hooks.
        self._prefetcher = _AsyncPrefetcher(self._store, self._layers)
        self._register_hooks()

    def _register_hooks(self) -> None:
        idx_map: dict[int, int] = {id(layer): idx for idx, layer in enumerate(self._layers)}
        num_layers = len(self._layers)

        compute_stream = torch.cuda.current_stream(self._target_device)

        def _pre_hook(
            module: nn.Module,
            _args: Any,  # noqa: ANN401
            *,
            idx: int,
        ) -> None:
            # Wait only for THIS layer's H2D transfer (not all pending ones).
            self._prefetcher.wait(idx)
            if not self._store.is_on_gpu(idx):
                self._store.move_to_gpu(idx, module)

            # Record that the compute stream will read these weight tensors.
            # They were allocated on the prefetch stream, so without this the
            # caching allocator would allow the prefetch stream to reuse their
            # memory immediately after eviction — even if the compute kernel
            # that reads them hasn't finished yet.
            for param in itertools.chain(module.parameters(), module.buffers()):
                param.data.record_stream(compute_stream)

            # Kick off prefetch for upcoming layers (wraps around for next pass).
            for offset in range(1, self._prefetch_count + 1):
                self._prefetcher.prefetch((idx + offset) % num_layers)

        def _post_hook(
            module: nn.Module,
            _args: Any,  # noqa: ANN401
            _output: Any,  # noqa: ANN401
            *,
            idx: int,
        ) -> None:
            # Evict this layer immediately — its computation is done.
            self._store.evict_to_cpu(idx, module)

        for layer in self._layers:
            idx = idx_map[id(layer)]
            h1 = layer.register_forward_pre_hook(functools.partial(_pre_hook, idx=idx))
            h2 = layer.register_forward_hook(functools.partial(_post_hook, idx=idx))
            self._hooks.extend([h1, h2])

    def teardown(self) -> None:
        """Remove hooks, release resources, and move parameters back to CPU.

        After this call the wrapper is inert: hooks are removed, the prefetch
        stream is drained and destroyed, all parameters reside on CPU, and the
        ``_LayerStore`` source data references are cleared. Callers should
        still follow up with ``.to("meta")`` to release the CPU copies if the
        model is no longer needed.
        """
        for h in self._hooks:
            h.remove()
        self._hooks.clear()

        # Drain all in-flight async H2D copies, then release stream resources.
        # Without the synchronize, clearing the stream/events can trigger
        # use-after-free at the CUDA driver level.
        torch.cuda.synchronize(device=self._target_device)
        if self._prefetcher is not None:
            self._prefetcher.cleanup()
            self._prefetcher = None

        # Move everything to CPU.
        for idx, layer in enumerate(self._layers):
            self._store.evict_to_cpu(idx, layer)

        for p in self._model.parameters():
            p.data = p.data.to("cpu")
        for b in self._model.buffers():
            b.data = b.data.to("cpu")

        # Release source data references. After evict_to_cpu() the layer
        # params point to the source data. The caller is expected to follow
        # up with .to("meta") to drop the param refs; cleanup() drops the
        # store's refs.
        self._store.cleanup()

    # ------------------------------------------------------------------
    # Forward and attribute delegation
    # ------------------------------------------------------------------

    def forward(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        return self._model(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Proxy attribute access to the wrapped model.

        This allows calling methods like ``encode()`` on a wrapped
        GemmaTextEncoder without the caller needing to know about the wrapper.
        ``nn.Module.__getattr__`` is only called when normal attribute lookup
        fails, so ``_model``, ``_store``, etc. are found first via ``__dict__``.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._model, name)
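A disposal sketch following the teardown contract above (`model` and `inputs`
are hypothetical):

wrapped = LayerStreamingWrapper(
    model, layers_attr="transformer_blocks", target_device=torch.device("cuda:0")
)
out = wrapped(inputs)
wrapped.teardown()  # hooks removed, prefetch stream drained, params on CPU
model.to("meta")    # as the docstring suggests: release the CPU copies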
ltx2/ltx_core/loader/__init__.py
ADDED
@@ -0,0 +1,48 @@
"""Loader utilities for model weights, LoRAs, and safetensor operations."""

from ltx_core.loader.fuse_loras import apply_loras
from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.primitives import (
    LoRAAdaptableProtocol,
    LoraPathStrengthAndSDOps,
    LoraStateDictWithStrength,
    ModelBuilderProtocol,
    StateDict,
    StateDictLoader,
)
from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry
from ltx_core.loader.sd_ops import (
    LTXV_LORA_COMFY_RENAMING_MAP,
    ContentMatching,
    ContentReplacement,
    KeyValueOperation,
    KeyValueOperationResult,
    SDKeyValueOperation,
    SDOps,
)
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder

__all__ = [
    "LTXV_LORA_COMFY_RENAMING_MAP",
    "ContentMatching",
    "ContentReplacement",
    "DummyRegistry",
    "KeyValueOperation",
    "KeyValueOperationResult",
    "LoRAAdaptableProtocol",
    "LoraPathStrengthAndSDOps",
    "LoraStateDictWithStrength",
    "ModelBuilderProtocol",
    "ModuleOps",
    "Registry",
    "SDKeyValueOperation",
    "SDOps",
    "SafetensorsModelStateDictLoader",
    "SafetensorsStateDictLoader",
    "SingleGPUModelBuilder",
    "StateDict",
    "StateDictLoader",
    "StateDictRegistry",
    "apply_loras",
]
ltx2/ltx_core/loader/fuse_loras.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections.abc import Iterator

import torch

from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
from ltx_core.quantization.fp8_cast import _fused_add_round_launch
from ltx_core.quantization.fp8_scaled_mm import quantize_weight_to_fp8_per_tensor


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


def fuse_lora_weights(
    model_sd: StateDict,
    lora_sd_and_strengths: list[LoraStateDictWithStrength],
    dtype: torch.dtype | None = None,
) -> Iterator[tuple[str, torch.Tensor]]:
    """Yield ``(key, fused_tensor)`` for each weight modified by at least one LoRA.
    For scaled-FP8 weights, this includes both the updated ``.weight`` tensor
    and its corresponding ``.weight_scale`` tensor.
    """
    for key, original_weight in model_sd.sd.items():
        if original_weight is None or key.endswith(".weight_scale"):
            continue
        original_device = original_weight.device
        weight = original_weight.to(device=_get_device())
        target_dtype = dtype if dtype is not None else weight.dtype
        deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16

        deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, weight.device)
        if deltas is None:
            continue

        scale_key = key.replace(".weight", ".weight_scale") if key.endswith(".weight") else None
        is_scaled_fp8 = scale_key is not None and scale_key in model_sd.sd

        if weight.dtype == torch.float8_e4m3fn:
            if is_scaled_fp8:
                fused = _fuse_delta_with_scaled_fp8(deltas, weight, key, scale_key, model_sd)
            else:
                fused = _fuse_delta_with_cast_fp8(deltas, weight, key, target_dtype)
        elif weight.dtype == torch.bfloat16:
            fused = _fuse_delta_with_bfloat16(deltas, weight, key, target_dtype)
        else:
            raise ValueError(f"Unsupported dtype: {weight.dtype}")

        for k, v in fused.items():
            yield k, v.to(device=original_device)


def apply_loras(
    model_sd: StateDict,
    lora_sd_and_strengths: list[LoraStateDictWithStrength],
    dtype: torch.dtype | None = None,
    destination_sd: StateDict | None = None,
) -> StateDict:
    if destination_sd is not None:
        sd = destination_sd.sd
        for key, tensor in fuse_lora_weights(model_sd, lora_sd_and_strengths, dtype):
            sd[key] = tensor
        return destination_sd

    fused = dict(fuse_lora_weights(model_sd, lora_sd_and_strengths, dtype))
    sd = {k: (fused[k] if k in fused else v.clone()) for k, v in model_sd.sd.items()}
    return StateDict(sd, model_sd.device, model_sd.size, model_sd.dtype)


def _prepare_deltas(
    lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
) -> torch.Tensor | None:
    deltas = []
    prefix = key[: -len(".weight")]
    key_a = f"{prefix}.lora_A.weight"
    key_b = f"{prefix}.lora_B.weight"
    for lsd, coef in lora_sd_and_strengths:
        if key_a not in lsd.sd or key_b not in lsd.sd:
            continue
        a = lsd.sd[key_a].to(device=device)
        b = lsd.sd[key_b].to(device=device)
        product = torch.matmul(b * coef, a)
        del a, b
        deltas.append(product.to(dtype=dtype))
    if len(deltas) == 0:
        return None
    elif len(deltas) == 1:
        return deltas[0]
    return torch.sum(torch.stack(deltas, dim=0), dim=0)


def _fuse_delta_with_scaled_fp8(
    deltas: torch.Tensor,
    weight: torch.Tensor,
    key: str,
    scale_key: str,
    model_sd: StateDict,
) -> dict[str, torch.Tensor]:
    """Dequantize scaled FP8 weight, add LoRA delta, and re-quantize."""
    weight_scale = model_sd.sd[scale_key]

    original_weight = weight.t().to(torch.float32) * weight_scale

    new_weight = original_weight + deltas.to(torch.float32)

    new_fp8_weight, new_weight_scale = quantize_weight_to_fp8_per_tensor(new_weight)
    return {key: new_fp8_weight, scale_key: new_weight_scale}


def _fuse_delta_with_cast_fp8(
    deltas: torch.Tensor,
    weight: torch.Tensor,
    key: str,
    target_dtype: torch.dtype,
) -> dict[str, torch.Tensor]:
    """Fuse LoRA delta with cast-only FP8 weight (no scale factor)."""
    if str(weight.device).startswith("cuda"):
        _fused_add_round_launch(deltas, weight, seed=0)
    else:
        deltas.add_(weight.to(dtype=deltas.dtype))
    return {key: deltas.to(dtype=target_dtype)}


def _fuse_delta_with_bfloat16(
    deltas: torch.Tensor,
    weight: torch.Tensor,
    key: str,
    target_dtype: torch.dtype,
) -> dict[str, torch.Tensor]:
    """Fuse LoRA delta with bfloat16 weight."""
    deltas.add_(weight)
    return {key: deltas.to(dtype=target_dtype)}
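A minimal end-to-end sketch of apply_loras on a toy state dict; the key names, shapes, and strength below are illustrative, not values shipped with the Space:

import torch
from ltx_core.loader.fuse_loras import apply_loras
from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict

w = torch.randn(8, 8, dtype=torch.bfloat16)
model_sd = StateDict(sd={"blocks.0.weight": w}, device=w.device, size=w.nbytes, dtype={w.dtype})

a = torch.randn(4, 8, dtype=torch.bfloat16)  # lora_A: (rank, in_features)
b = torch.randn(8, 4, dtype=torch.bfloat16)  # lora_B: (out_features, rank)
lora_sd = StateDict(
    sd={"blocks.0.lora_A.weight": a, "blocks.0.lora_B.weight": b},
    device=a.device,
    size=a.nbytes + b.nbytes,
    dtype={a.dtype},
)

fused = apply_loras(model_sd, [LoraStateDictWithStrength(lora_sd, strength=0.5)])
# fused.sd["blocks.0.weight"] now equals w + 0.5 * (b @ a), computed in bfloat16.

Since no destination_sd is passed, untouched tensors are cloned and the input StateDict is left unmodified.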
ltx2/ltx_core/loader/kernels.py
ADDED
@@ -0,0 +1,72 @@
# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
import triton
import triton.language as tl


@triton.jit
def fused_add_round_kernel(
    x_ptr,
    output_ptr,  # contents will be added to the output
    seed,
    n_elements,
    EXPONENT_BIAS,
    MANTISSA_BITS,
    BLOCK_SIZE: tl.constexpr,
):
    """
    A kernel that upcasts 8-bit quantized weights to bfloat16 with stochastic
    rounding and adds them to bfloat16 output weights. It can be used to upcast
    original model weights and add them to precalculated deltas coming from LoRAs.
    """
    # Get program ID and compute offsets
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    rand_vals = tl.rand(seed, offsets) - 0.5

    x = tl.cast(x, tl.float16)
    delta = tl.load(output_ptr + offsets, mask=mask)
    delta = tl.cast(delta, tl.float16)
    x = x + delta

    x_bits = tl.cast(x, tl.int16, bitcast=True)

    # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
    # normal numbers and -14 for subnormals.
    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
    fp16_normals = fp16_exponent_bits > 0
    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)

    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
    exponent = fp16_exponent + EXPONENT_BIAS
    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
    exponent = tl.where(exponent < 0, 0, exponent)

    # Normal ULP exponent, expressed as an fp16 exponent field:
    # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))

    # Calculate epsilon in the target dtype
    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)

    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
    # 16 - EXPONENT_BIAS - MANTISSA_BITS
    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)

    # Apply zero mask to epsilon
    eps = tl.where(x == 0, 0.0, eps)

    # Apply stochastic rounding
    output = tl.cast(x + rand_vals * eps, tl.bfloat16)

    # Store the result
    tl.store(output_ptr + offsets, output, mask=mask)
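The launch wrapper _fused_add_round_launch lives in ltx_core.quantization.fp8_cast and is not included in this diff. A hypothetical launcher, consistent with how fuse_loras.py calls it, might look like the sketch below; the two constants are properties of torch.float8_e4m3fn (exponent bias 7, 3 mantissa bits), and the grid/block sizing is an assumption:

import torch
import triton

from ltx_core.loader.kernels import fused_add_round_kernel

def fused_add_round_launch(deltas: torch.Tensor, weight: torch.Tensor, seed: int = 0) -> None:
    """Sketch: in place, deltas (bfloat16) += weight (fp8), with stochastic rounding."""
    n_elements = deltas.numel()
    BLOCK_SIZE = 1024  # assumption; the real wrapper may tune this
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    fused_add_round_kernel[grid](
        weight,      # x_ptr: 8-bit weights to upcast
        deltas,      # output_ptr: read, accumulated into, and overwritten as bfloat16
        seed,
        n_elements,
        7,           # EXPONENT_BIAS for float8_e4m3fn
        3,           # MANTISSA_BITS for float8_e4m3fn
        BLOCK_SIZE=BLOCK_SIZE,
    )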
ltx2/ltx_core/loader/module_ops.py
ADDED
@@ -0,0 +1,14 @@
from typing import Callable, NamedTuple

import torch


class ModuleOps(NamedTuple):
    """
    Defines a named operation for matching and mutating PyTorch modules.
    Used to selectively transform modules in a model (e.g., replacing layers with quantized versions).
    """

    name: str
    matcher: Callable[[torch.nn.Module], bool]
    mutator: Callable[[torch.nn.Module], torch.nn.Module]
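SingleGPUModelBuilder (further down in this diff) applies each ModuleOps to the whole meta-device model: the matcher decides whether the mutator runs. A minimal illustrative ModuleOps, not one shipped in the repo, could cast every Linear layer to bfloat16:

import torch

from ltx_core.loader.module_ops import ModuleOps

def _has_linear(model: torch.nn.Module) -> bool:
    # Run the mutator only when there is something to transform.
    return any(isinstance(m, torch.nn.Linear) for m in model.modules())

def _linears_to_bf16(model: torch.nn.Module) -> torch.nn.Module:
    # Works on meta-device models too: only the dtype changes, no memory is allocated.
    for m in model.modules():
        if isinstance(m, torch.nn.Linear):
            m.to(torch.bfloat16)
    return model

bf16_linears = ModuleOps(name="bf16_linears", matcher=_has_linear, mutator=_linears_to_bf16)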
ltx2/ltx_core/loader/primitives.py
ADDED
@@ -0,0 +1,146 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Protocol

import torch

from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.sd_ops import SDOps
from ltx_core.model.model_protocol import ModelType

if TYPE_CHECKING:
    from ltx_core.loader.registry import Registry


@dataclass(frozen=True)
class StateDict:
    """
    Immutable container for a PyTorch state dictionary.
    Contains:
    - sd: Dictionary of tensors (weights, buffers, etc.)
    - device: Device where tensors are stored
    - size: Total memory footprint in bytes
    - dtype: Set of tensor dtypes present
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        return self.size, self.device


class StateDictLoader(Protocol):
    """
    Protocol for loading state dictionaries from various sources.
    Implementations must provide:
    - metadata: Extract model metadata from a single path
    - load: Load state dict from path(s) and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Load metadata from path
        """

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load state dict from path or paths (for sharded model storage) and apply sd_ops
        """


class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for building PyTorch models from configuration dictionaries.
    Implementations must provide:
    - meta_model: Create a model from configuration dictionary and apply module operations
    - build: Create and initialize a model from state dictionary and apply dtype transformations
    """

    model_sd_ops: SDOps | None
    module_ops: tuple[ModuleOps, ...]
    loras: tuple["LoraPathStrengthAndSDOps", ...]
    registry: "Registry"

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.
        This decouples model creation from weight loading, allowing the model
        architecture to be instantiated without allocating memory for parameters.
        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).
        Returns:
            Model instance on meta device (no actual memory allocated for parameters).
        """
        ...

    def with_sd_ops(self, sd_ops: SDOps | None) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder with the given state-dict key remapping ops."""
        ...

    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder with the given module operations (e.g. quantization)."""
        ...

    def with_loras(self, loras: tuple["LoraPathStrengthAndSDOps", ...]) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder with the given LoRAs to fuse at build time."""
        ...

    def with_registry(self, registry: "Registry") -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder using the given weight registry for allocation."""
        ...

    def with_lora_load_device(self, device: torch.device) -> "ModelBuilderProtocol[ModelType]":
        """Return a copy of this builder that loads LoRA weights onto the given device."""
        ...

    def build(
        self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object
    ) -> ModelType:
        """
        Build the model
        Args:
            device: Target device for the model
            dtype: Target dtype for the model, if None, uses the dtype of the model_path model
        Returns:
            Model instance
        """
        ...

    def model_config(self) -> dict:
        """Return the model configuration dictionary extracted from the checkpoint metadata."""
        ...


class LoRAAdaptableProtocol(Protocol):
    """
    Protocol for models that can be adapted with LoRAs.
    Implementations must provide:
    - lora: Add a LoRA to the model
    """

    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
        pass


class LoraPathStrengthAndSDOps(NamedTuple):
    """
    Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict.
    """

    path: str
    strength: float
    sd_ops: SDOps


class LoraStateDictWithStrength(NamedTuple):
    """
    Tuple containing a LoRA state dict and strength for applying to the model.
    """

    state_dict: StateDict
    strength: float
ltx2/ltx_core/loader/registry.py
ADDED
@@ -0,0 +1,84 @@
import hashlib
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Protocol

from ltx_core.loader.primitives import StateDict
from ltx_core.loader.sd_ops import SDOps


class Registry(Protocol):
    """
    Protocol for managing state dictionaries in a registry.
    It is used to store state dictionaries and reuse them later without loading them again.
    Implementations must provide:
    - add: Add a state dictionary to the registry
    - pop: Remove a state dictionary from the registry
    - get: Retrieve a state dictionary from the registry
    - clear: Clear all state dictionaries from the registry
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    def clear(self) -> None: ...


class DummyRegistry(Registry):
    """
    Dummy registry that does not store state dictionaries.
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        pass

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        pass

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        pass

    def clear(self) -> None:
        pass


@dataclass
class StateDictRegistry(Registry):
    """
    Registry that stores state dictionaries in a dictionary.
    """

    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
        m = hashlib.sha256()
        parts = [str(Path(p).resolve()) for p in paths]
        if sd_ops is not None:
            parts.append(sd_ops.name)
        m.update("\0".join(parts).encode("utf-8"))
        return m.hexdigest()

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            if sd_id in self._state_dicts:
                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
            self._state_dicts[sd_id] = state_dict
        return sd_id

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        with self._lock:
            return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        with self._lock:
            return self._state_dicts.get(self._generate_id(paths, sd_ops), None)

    def clear(self) -> None:
        with self._lock:
            self._state_dicts.clear()
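A short usage sketch of the registry as a load-once cache (the path is hypothetical); this mirrors the get/load/add pattern used by SingleGPUModelBuilder.load_sd later in this diff:

import torch

from ltx_core.loader.registry import StateDictRegistry
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader

registry = StateDictRegistry()
loader = SafetensorsModelStateDictLoader()
paths = ["/models/ltx2_transformer.safetensors"]  # hypothetical checkpoint path

state = registry.get(paths, sd_ops=None)
if state is None:
    state = loader.load(paths, sd_ops=None, device=torch.device("cpu"))
    registry.add(paths, sd_ops=None, state_dict=state)

# A second lookup with the same (paths, sd_ops) returns the cached StateDict.
assert registry.get(paths, sd_ops=None) is state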
ltx2/ltx_core/loader/sd_ops.py
ADDED
@@ -0,0 +1,139 @@
from dataclasses import dataclass, replace
from typing import NamedTuple, Protocol

import torch


@dataclass(frozen=True, slots=True)
class ContentReplacement:
    """
    Represents a content replacement operation.
    Used to replace a specific content with a replacement in a state dict key.
    """

    content: str
    replacement: str


@dataclass(frozen=True, slots=True)
class ContentMatching:
    """
    Represents a content matching operation.
    Used to match a specific prefix and suffix in a state dict key.
    """

    prefix: str = ""
    suffix: str = ""


class KeyValueOperationResult(NamedTuple):
    """
    Represents the result of a key-value operation.
    Contains the new key and value after the operation has been applied.
    """

    new_key: str
    new_value: torch.Tensor


class KeyValueOperation(Protocol):
    """
    Protocol for key-value operations.
    Used to apply operations to a specific key and value in a state dict.
    """

    def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...


@dataclass(frozen=True, slots=True)
class SDKeyValueOperation:
    """
    Represents a key-value operation.
    Used to apply operations to a specific key and value in a state dict.
    """

    key_matcher: ContentMatching
    kv_operation: KeyValueOperation


@dataclass(frozen=True, slots=True)
class SDOps:
    """Immutable class representing state dict key operations."""

    name: str
    mapping: tuple[
        ContentReplacement | ContentMatching | SDKeyValueOperation, ...
    ] = ()  # Immutable tuple of matching, replacement, and key-value operations
    allowed_keys: frozenset[str] | None = None

    def with_replacement(self, content: str, replacement: str) -> "SDOps":
        """Create a new SDOps instance with the specified replacement added to the mapping."""

        new_mapping = (*self.mapping, ContentReplacement(content, replacement))
        return replace(self, mapping=new_mapping)

    def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
        """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""

        new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
        return replace(self, mapping=new_mapping)

    def with_additional_allowed_keys(self, keys: frozenset[str]) -> "SDOps":
        """Create a new SDOps instance that only passes keys present in *keys* (post-replacement).
        If allowed_keys already exists, the sets are merged via union.
        """
        merged = frozenset(keys) | self.allowed_keys if self.allowed_keys is not None else frozenset(keys)
        return replace(self, allowed_keys=merged)

    def with_kv_operation(
        self,
        operation: KeyValueOperation,
        key_prefix: str = "",
        key_suffix: str = "",
    ) -> "SDOps":
        """Create a new SDOps instance with the specified value operation added to the mapping."""
        key_matcher = ContentMatching(key_prefix, key_suffix)
        sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
        new_mapping = (*self.mapping, sd_kv_operation)
        return replace(self, mapping=new_mapping)

    def apply_to_key(self, key: str) -> str | None:
        """Apply the mapping to the given name."""
        matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
        valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
        if not valid:
            return None

        for replacement in self.mapping:
            if not isinstance(replacement, ContentReplacement):
                continue
            if replacement.content in key:
                key = key.replace(replacement.content, replacement.replacement)

        if self.allowed_keys is not None and key not in self.allowed_keys:
            return None

        return key

    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
        """Apply the value operation to the given name and associated value."""
        for operation in self.mapping:
            if not isinstance(operation, SDKeyValueOperation):
                continue
            if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
                return operation.kv_operation(key, value)
        return [KeyValueOperationResult(key, value)]


# Predefined SDOps instances
LTXV_LORA_COMFY_RENAMING_MAP = (
    SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
)

LTXV_LORA_COMFY_TARGET_MAP = (
    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
    .with_matching()
    .with_replacement("diffusion_model.", "")
    .with_replacement(".lora_A.weight", ".weight")
    .with_replacement(".lora_B.weight", ".weight")
)
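To illustrate apply_to_key with made-up keys: the bare with_matching() accepts every key, after which the ComfyUI-style "diffusion_model." prefix is stripped, while an SDOps with only a prefix matcher filters non-matching keys out entirely:

from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP, SDOps

key = "diffusion_model.transformer_blocks.0.attn.to_q.lora_A.weight"  # illustrative key
print(LTXV_LORA_COMFY_RENAMING_MAP.apply_to_key(key))
# -> transformer_blocks.0.attn.to_q.lora_A.weight

vae_only = SDOps("VAE_ONLY").with_matching(prefix="vae.")
print(vae_only.apply_to_key("transformer.proj.weight"))  # -> None (dropped)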
ltx2/ltx_core/loader/sft_loader.py
ADDED
@@ -0,0 +1,66 @@
import json

import safetensors
import torch

from ltx_core.loader.primitives import StateDict, StateDictLoader
from ltx_core.loader.sd_ops import SDOps


class SafetensorsStateDictLoader(StateDictLoader):
    """
    Loads weights from safetensors files without metadata support.
    Use this for loading raw weight files. For model files that include
    configuration metadata, use SafetensorsModelStateDictLoader instead.
    """

    def metadata(self, path: str) -> dict:
        raise NotImplementedError("Not implemented")

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load state dict from path or paths (for sharded model storage) and apply sd_ops
        """
        sd = {}
        size = 0
        dtype = set()
        device = device or torch.device("cpu")
        model_paths = path if isinstance(path, list) else [path]
        for shard_path in model_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                safetensor_keys = f.keys()
                for name in safetensor_keys:
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    if expected_name is None:
                        continue
                    value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    key_value_pairs = ((expected_name, value),)
                    if sd_ops is not None:
                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
                    for key, value in key_value_pairs:
                        size += value.nbytes
                        dtype.add(value.dtype)
                        sd[key] = value

        return StateDict(sd=sd, device=device, size=size, dtype=dtype)


class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Loads weights and configuration metadata from safetensors model files.
    Unlike SafetensorsStateDictLoader, this loader can read model configuration
    from the safetensors file metadata via the metadata() method.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader()

    def metadata(self, path: str) -> dict:
        with safetensors.safe_open(path, framework="pt") as f:
            meta = f.metadata()
        if meta is None or "config" not in meta:
            return {}
        return json.loads(meta["config"])

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        return self.weight_loader.load(path, sd_ops, device)
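A minimal usage sketch of the metadata-aware loader (the checkpoint path is hypothetical):

import torch

from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader

loader = SafetensorsModelStateDictLoader()
path = "/models/ltx2_transformer.safetensors"  # hypothetical path
config = loader.metadata(path)                 # {} when no "config" entry is embedded
state = loader.load(path, device=torch.device("cpu"))
print(state.size, state.dtype)                 # total bytes loaded and the set of dtypes seen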
ltx2/ltx_core/loader/single_gpu_model_builder.py
ADDED
@@ -0,0 +1,136 @@
import logging
from dataclasses import dataclass, field, replace
from typing import Generic

import torch

from ltx_core.loader.fuse_loras import apply_loras
from ltx_core.loader.module_ops import ModuleOps
from ltx_core.loader.primitives import (
    LoRAAdaptableProtocol,
    LoraPathStrengthAndSDOps,
    LoraStateDictWithStrength,
    ModelBuilderProtocol,
    StateDict,
    StateDictLoader,
)
from ltx_core.loader.registry import DummyRegistry, Registry
from ltx_core.loader.sd_ops import SDOps
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
from ltx_core.model.model_protocol import ModelConfigurator, ModelType

logger: logging.Logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
    """
    Builder for PyTorch models residing on a single GPU.
    Attributes:
        model_class_configurator: Class responsible for constructing the model from a config dict.
        model_path: Path (or tuple of shard paths) to the model's `.safetensors` checkpoint(s).
        model_sd_ops: Optional state-dict operations applied when loading the model weights.
        module_ops: Sequence of module-level mutations applied to the meta model before weight loading.
        loras: Sequence of LoRA adapters (path, strength, optional sd_ops) to fuse into the model.
        model_loader: Strategy for loading state dicts from disk. Defaults to
            :class:`SafetensorsModelStateDictLoader`.
        registry: Cache for already-loaded state dicts. Defaults to :class:`DummyRegistry` (no caching).
        lora_load_device: Device used when loading LoRA weight tensors from disk. Defaults to
            ``torch.device("cpu")``, which keeps LoRA weights in CPU memory and transfers them to
            the target GPU sequentially during fusion, reducing peak GPU memory usage compared to
            loading all LoRA weights directly onto the GPU at once.
    """

    model_class_configurator: type[ModelConfigurator[ModelType]]
    model_path: str | tuple[str, ...]
    model_sd_ops: SDOps | None = None
    module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
    loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
    model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
    registry: Registry = field(default_factory=DummyRegistry)
    lora_load_device: torch.device = field(default_factory=lambda: torch.device("cpu"))

    def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
        return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))

    def with_sd_ops(self, sd_ops: SDOps | None) -> "SingleGPUModelBuilder":
        return replace(self, model_sd_ops=sd_ops)

    def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "SingleGPUModelBuilder":
        return replace(self, module_ops=module_ops)

    def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> "SingleGPUModelBuilder":
        return replace(self, loras=loras)

    def with_registry(self, registry: Registry) -> "SingleGPUModelBuilder":
        return replace(self, registry=registry)

    def with_lora_load_device(self, device: torch.device) -> "SingleGPUModelBuilder":
        return replace(self, lora_load_device=device)

    def model_config(self) -> dict:
        first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
        return self.model_loader.metadata(first_shard_path)

    def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
        with torch.device("meta"):
            model = self.model_class_configurator.from_config(config)
            for module_op in module_ops:
                if module_op.matcher(model):
                    model = module_op.mutator(model)
        return model

    def load_sd(
        self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
    ) -> StateDict:
        state_dict = registry.get(paths, sd_ops)
        if state_dict is None:
            state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
            registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
        return state_dict

    def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
        uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
        uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
        if uninitialized_params or uninitialized_buffers:
            logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
            return meta_model
        retval = meta_model.to(device)
        return retval

    def build(
        self,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        **kwargs: object,  # noqa: ARG002
    ) -> ModelType:
        device = torch.device("cuda") if device is None else device
        config = self.model_config()
        meta_model = self.meta_model(config, self.module_ops)
        model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path]
        model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)

        lora_strengths = [lora.strength for lora in self.loras]
        if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
            sd = model_state_dict.sd
            if dtype is not None:
                sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
            meta_model.load_state_dict(sd, strict=False, assign=True)
            return self._return_model(meta_model, device)

        lora_state_dicts = [
            self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=self.lora_load_device)
            for lora in self.loras
        ]
        lora_sd_and_strengths = [
            LoraStateDictWithStrength(sd, strength)
            for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
        ]
        final_sd = apply_loras(
            model_sd=model_state_dict,
            lora_sd_and_strengths=lora_sd_and_strengths,
            dtype=dtype,
            destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
        )
        meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
        return self._return_model(meta_model, device)
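Putting the loader pieces together, a hypothetical build that fuses one LoRA at load time could look like the following sketch; MyTransformerConfigurator and both file paths are placeholders, not names from this repo:

import torch

from ltx_core.loader.registry import StateDictRegistry
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder

builder = (
    SingleGPUModelBuilder(
        model_class_configurator=MyTransformerConfigurator,  # hypothetical ModelConfigurator subclass
        model_path="/models/ltx2_transformer.safetensors",   # hypothetical checkpoint path
    )
    .with_registry(StateDictRegistry())
    .lora("/loras/style.safetensors", strength=0.8, sd_ops=LTXV_LORA_COMFY_RENAMING_MAP)
)
model = builder.build(device=torch.device("cuda"), dtype=torch.bfloat16)

Note the destination_sd logic in build: with a caching registry the fused weights go into a fresh copy so the cached base state dict stays pristine, while the default DummyRegistry fuses in place to save memory.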
ltx2/ltx_core/modality_tiling.py
ADDED
@@ -0,0 +1,222 @@
"""Video modality tiling helpers.
Provides :class:`VideoModalityTilingHelper` — a stateless helper that
tiles and blends video :class:`Modality` token sequences by
spatial/temporal region. Tile geometry is represented by the existing
:class:`Tile` NamedTuple from :mod:`ltx_core.tiling`; no distributed
primitives are required.
"""

from __future__ import annotations

from dataclasses import dataclass, replace

import torch

from ltx_core.model.transformer.modality import Modality
from ltx_core.tiling import Tile, TileCountConfig, create_tiles, identity_mapping_operation, split_by_count
from ltx_core.tools import VideoLatentTools
from ltx_core.types import VideoLatentShape


@dataclass(frozen=True)
class TilingContext:
    """Opaque context produced by :meth:`VideoModalityTilingHelper.tile_modality`.
    Carries the token-level keep mask and per-conditioning-token blend
    weights needed by :meth:`~VideoModalityTilingHelper.blend`.
    """

    keep_mask: torch.Tensor
    cond_blend_weights: torch.Tensor | None
    """``(num_kept_cond,)`` — weight for each kept conditioning token,
    equal to ``1 / num_tiles_that_keep_this_token``. ``None`` when
    there are no conditioning tokens."""


class VideoModalityTilingHelper:
    """Stateless helper that tiles and blends video :class:`Modality` sequences.
    Constructed once with a :class:`TileCountConfig` and
    :class:`VideoLatentTools`. Tiles are computed at construction and
    available via the :attr:`tiles` property. Use :meth:`tile_modality`
    and :meth:`blend` with any tile from that list.
    Usage::
        helper = VideoModalityTilingHelper(tiling, video_tools)
        for tile in helper.tiles:
            tiled_mod, ctx = helper.tile_modality(modality, tile)
            result = run_model(tiled_mod)
            helper.blend(result, tile, ctx, output=output)
    """

    def __init__(self, tiling: TileCountConfig, video_tools: VideoLatentTools) -> None:
        self._patchifier = video_tools.patchifier
        self._latent_shape = video_tools.target_shape
        self._num_generated_tokens = self._patchifier.get_token_count(self._latent_shape)
        self._tiles = create_tiles(
            torch.Size([self._latent_shape.frames, self._latent_shape.height, self._latent_shape.width]),
            splitters=[
                split_by_count(tiling.frames.num_tiles, tiling.frames.overlap),
                split_by_count(tiling.height.num_tiles, tiling.height.overlap),
                split_by_count(tiling.width.num_tiles, tiling.width.overlap),
            ],
            mappers=[identity_mapping_operation] * 3,
        )

    @property
    def tiles(self) -> list[Tile]:
        """All tiles for the configured tiling layout."""
        return self._tiles

    # -- tile modality -----------------------------------------------------

    def tile_modality(self, modality: Modality, tile: Tile) -> tuple[Modality, TilingContext]:
        """Slice *modality* to the tokens covered by *tile*.
        Selects generated tokens belonging to the tile's spatial region
        and conditioning tokens that overlap with the tile (or have
        negative time coordinates).
        Returns:
            A ``(tiled_modality, context)`` tuple. Pass *context* to
            :meth:`blend` together with the model output.
        """
        keep_mask = self._keep_mask(modality, tile)

        tile_attention_mask = None
        if modality.attention_mask is not None:
            keep_indices = keep_mask.nonzero(as_tuple=False).squeeze(1)
            tile_attention_mask = modality.attention_mask[:, keep_indices, :][:, :, keep_indices]

        tiled = replace(
            modality,
            latent=modality.latent[:, keep_mask, :],
            timesteps=modality.timesteps[:, keep_mask],
            positions=modality.positions[:, :, keep_mask, :],
            attention_mask=tile_attention_mask,
        )

        cond_blend_weights = None
        num_total = modality.latent.shape[1]
        if num_total > self._num_generated_tokens:
            cond_keep = keep_mask[self._num_generated_tokens :]
            # Count how many tiles keep each conditioning token.
            cond_counts = torch.zeros(cond_keep.sum(), dtype=torch.float32)
            for t in self._tiles:
                other_mask = self._keep_mask(modality, t)
                other_cond = other_mask[self._num_generated_tokens :]
                # Map other tile's kept cond tokens into this tile's kept subset.
                cond_counts += other_cond[cond_keep].float()
            cond_blend_weights = 1.0 / cond_counts

        return tiled, TilingContext(keep_mask=keep_mask, cond_blend_weights=cond_blend_weights)

    # -- blend -------------------------------------------------------------

    def blend(
        self,
        tile_to_blend: torch.Tensor,
        tile: Tile,
        context: TilingContext,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Blend-weight tile results and accumulate into the full token space.
        Premultiplied (blend-weighted) data is **added** to *output*,
        allowing multiple tiles to be accumulated into the same buffer.
        Args:
            tile_to_blend: Denoised tile tensor ``(B, num_tile_tokens, D)``,
                where the first ``_tile_generated_token_count(tile)``
                entries are generated tokens and the remainder are
                conditioning tokens.
            tile: The :class:`Tile` that was used in :meth:`tile_modality`.
            context: The :class:`TilingContext` returned by :meth:`tile_modality`.
            output: Optional pre-allocated output tensor. When provided
                its shape must be ``(B, num_total_tokens, D)`` and the
                blended tile is **added** into it. When ``None`` a new
                zero-filled tensor is created.
        Returns:
            The output tensor with the blended tile added at the correct
            positions.
        """
        batch, _, dim = tile_to_blend.shape
        num_tile_gen = self._tile_generated_token_count(tile)
        gen_indices = self._generated_token_indices(tile)

        num_total_tokens = context.keep_mask.shape[0]
        expected_shape = (batch, num_total_tokens, dim)

        if output is not None:
            if output.shape != expected_shape:
                raise ValueError(f"Expected output shape {expected_shape}, got {output.shape}")
            result = output
        else:
            result = torch.zeros(*expected_shape, device=tile_to_blend.device, dtype=tile_to_blend.dtype)

        # Blend mask is (tile_F, tile_H, tile_W) — one weight per token in row-major order.
        blend_weights = tile.blend_mask.reshape(-1).to(device=tile_to_blend.device, dtype=tile_to_blend.dtype)
        tile_gen = tile_to_blend[:, :num_tile_gen, :] * blend_weights[None, :, None]

        result[:, gen_indices, :] += tile_gen

        # Scatter kept conditioning tokens, weighted by 1/N where N is
        # the number of tiles that keep each token (so they sum to 1).
        if num_total_tokens > self._num_generated_tokens and context.cond_blend_weights is not None:
            cond_keep = context.keep_mask[self._num_generated_tokens :]
            cond_indices = self._num_generated_tokens + cond_keep.nonzero(as_tuple=False).squeeze(1)
            weights = context.cond_blend_weights.to(device=tile_to_blend.device, dtype=tile_to_blend.dtype)
            result[:, cond_indices, :] += tile_to_blend[:, num_tile_gen:, :] * weights[None, :, None]

        return result

    # -- private -----------------------------------------------------------

    def _tile_generated_token_count(self, tile: Tile) -> int:
        """Number of generated tokens in *tile*."""
        frame_slice, height_slice, width_slice = tile.in_coords
        tile_shape = VideoLatentShape(
            batch=self._latent_shape.batch,
            channels=self._latent_shape.channels,
            frames=frame_slice.stop - frame_slice.start,
            height=height_slice.stop - height_slice.start,
            width=width_slice.stop - width_slice.start,
        )
        return self._patchifier.get_token_count(tile_shape)

    def _generated_token_indices(self, tile: Tile) -> torch.Tensor:
        """Flat token indices of *tile*'s generated tokens in the full sequence."""
        frame_slice, height_slice, width_slice = tile.in_coords
        f = torch.arange(frame_slice.start, frame_slice.stop)
        h = torch.arange(height_slice.start, height_slice.stop)
        w = torch.arange(width_slice.start, width_slice.stop)
        return (
            f[:, None, None] * self._latent_shape.height * self._latent_shape.width
            + h[None, :, None] * self._latent_shape.width
            + w[None, None, :]
        ).reshape(-1)

    def _keep_mask(self, modality: Modality, tile: Tile) -> torch.Tensor:
        """Boolean mask ``(num_total_tokens,)`` — True for tokens the tile processes.
        Generated tokens are selected by grid position. Conditioning
        tokens are kept when their ``[start, end)`` intervals overlap
        the tile in all three dimensions, or when they have a negative
        time coordinate (reference tokens).
        """
        num_total = modality.latent.shape[1]
        mask = torch.zeros(num_total, dtype=torch.bool)

        gen_indices = self._generated_token_indices(tile)
        mask[gen_indices] = True

        if num_total > self._num_generated_tokens:
            gen_positions = modality.positions[:, :, gen_indices, :]  # (B, 3, num_tile_gen, 2)
            tile_start = gen_positions[..., 0].amin(dim=2)  # (B, 3)
            tile_end = gen_positions[..., 1].amax(dim=2)  # (B, 3)

            cond_positions = modality.positions[:, :, self._num_generated_tokens :, :]  # (B, 3, num_cond, 2)

            overlaps = (cond_positions[..., 0] < tile_end.unsqueeze(2)) & (
                cond_positions[..., 1] > tile_start.unsqueeze(2)
            )  # (B, 3, num_cond)
            overlaps_all_dims = overlaps.all(dim=1)  # (B, num_cond)

            has_negative_time = cond_positions[:, 0, :, 0] < 0  # (B, num_cond)

            keep_cond = (overlaps_all_dims | has_negative_time).any(dim=0)  # (num_cond,)
            mask[self._num_generated_tokens :] = keep_cond

        return mask
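As a quick check on the row-major layout assumed by _generated_token_indices: generated tokens are ordered frame-major, then by row, then by column, so for a toy latent grid the flat index works out as follows:

# Toy check of the flat-index formula f * height * width + h * width + w:
frames, height, width = 2, 4, 8
f, h, w = 1, 2, 3
flat = f * height * width + h * width + w
assert flat == 51  # token (1, 2, 3) is the 52nd token in the full sequence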
ltx2/ltx_core/model/__init__.py
ADDED
@@ -0,0 +1,8 @@
"""Model definitions for LTX-2."""

from ltx_core.model.model_protocol import ModelConfigurator, ModelType

__all__ = [
    "ModelConfigurator",
    "ModelType",
]
ltx2/ltx_core/model/audio_vae/__init__.py
ADDED
@@ -0,0 +1,29 @@
"""Audio VAE model components."""

from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio, encode_audio
from ltx_core.model.audio_vae.model_configurator import (
    AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
    VOCODER_COMFY_KEYS_FILTER,
    AudioDecoderConfigurator,
    AudioEncoderConfigurator,
    VocoderConfigurator,
)
from ltx_core.model.audio_vae.ops import AudioProcessor
from ltx_core.model.audio_vae.vocoder import Vocoder, VocoderWithBWE

__all__ = [
    "AUDIO_VAE_DECODER_COMFY_KEYS_FILTER",
    "AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER",
    "VOCODER_COMFY_KEYS_FILTER",
    "AudioDecoder",
    "AudioDecoderConfigurator",
    "AudioEncoder",
    "AudioEncoderConfigurator",
    "AudioProcessor",
    "Vocoder",
    "VocoderConfigurator",
    "VocoderWithBWE",
    "decode_audio",
    "encode_audio",
]
ltx2/ltx_core/model/audio_vae/attention.py
ADDED
@@ -0,0 +1,71 @@
from enum import Enum

import torch

from ltx_core.model.common.normalization import NormType, build_normalization_layer


class AttentionType(Enum):
    """Enum for specifying the attention mechanism type."""

    VANILLA = "vanilla"
    LINEAR = "linear"
    NONE = "none"


class AttnBlock(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw  w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w).contiguous()
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_).contiguous()  # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w).contiguous()

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    match attn_type:
        case AttentionType.VANILLA:
            return AttnBlock(in_channels, norm_type=norm_type)
        case AttentionType.NONE:
            return torch.nn.Identity()
        case AttentionType.LINEAR:
            raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
        case _:
            raise ValueError(f"Unknown attention type: {attn_type}")
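For reference, the bmm/softmax arithmetic in AttnBlock.forward is plain single-head scaled dot-product self-attention over the h*w spatial positions. An equivalent formulation via torch.nn.functional.scaled_dot_product_attention (a sketch for comparison, not used by the vendored code) is:

import torch
import torch.nn.functional as F

def attn_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (b, c, h, w); attention runs over the h*w positions with head dim c.
    b, c, h, w = q.shape
    q_ = q.reshape(b, c, h * w).transpose(1, 2)  # (b, hw, c)
    k_ = k.reshape(b, c, h * w).transpose(1, 2)
    v_ = v.reshape(b, c, h * w).transpose(1, 2)
    out = F.scaled_dot_product_attention(q_, k_, v_)  # softmax(q k^T / sqrt(c)) v
    return out.transpose(1, 2).reshape(b, c, h, w)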
ltx2/ltx_core/model/audio_vae/audio_vae.py
ADDED
@@ -0,0 +1,508 @@
from typing import Set, Tuple

import torch
import torch.nn.functional as F

from ltx_core.components.patchifiers import AudioPatchifier
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
from ltx_core.model.audio_vae.downsample import build_downsampling_path
from ltx_core.model.audio_vae.ops import AudioProcessor, PerChannelStatistics
from ltx_core.model.audio_vae.resnet import ResnetBlock
from ltx_core.model.audio_vae.upsample import build_upsampling_path
from ltx_core.model.audio_vae.vocoder import Vocoder
from ltx_core.model.common.normalization import NormType, build_normalization_layer
from ltx_core.types import Audio, AudioLatentShape

LATENT_DOWNSAMPLE_FACTOR = 4


def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """Build the middle block with two ResNet blocks and optional attention."""
    mid = torch.nn.Module()
    mid.block_1 = ResnetBlock(
        in_channels=channels,
        out_channels=channels,
        temb_channels=temb_channels,
        dropout=dropout,
        norm_type=norm_type,
        causality_axis=causality_axis,
    )
    mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()
    mid.block_2 = ResnetBlock(
        in_channels=channels,
        out_channels=channels,
        temb_channels=temb_channels,
        dropout=dropout,
        norm_type=norm_type,
        causality_axis=causality_axis,
    )
    return mid


def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Run features through the middle block."""
    features = mid.block_1(features, temb=None)
    features = mid.attn_1(features)
    return mid.block_2(features, temb=None)


class AudioEncoder(torch.nn.Module):
    """
    Encoder that compresses audio spectrograms into latent representations.
    The encoder uses a series of downsampling blocks with residual connections,
    attention mechanisms, and configurable causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int,
        resolution: int,
        z_channels: int,
        double_z: bool = True,
        attn_type: AttentionType = AttentionType.VANILLA,
        mid_block_add_attention: bool = True,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
        is_causal: bool = True,
        mel_bins: int = 64,
        **_ignore_kwargs,
    ) -> None:
        """
        Initialize the Encoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            ch: Base number of feature channels used in the first convolution layer.
            ch_mult: Multiplicative factors for the number of channels at each resolution level.
            num_res_blocks: Number of residual blocks to use at each resolution level.
            attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
            resolution: Input spatial resolution of the spectrogram (height, width).
            z_channels: Number of channels in the latent representation.
            norm_type: Normalization layer type to use within the network (e.g., group, batch).
            causality_axis: Axis along which convolutions should be causal (e.g., time axis).
            sample_rate: Audio sample rate in Hz for the input signals.
            mel_hop_length: Hop length used when computing the mel spectrogram.
            n_fft: FFT size used to compute the spectrogram.
            mel_bins: Number of mel-frequency bins in the input spectrogram.
            in_channels: Number of channels in the input spectrogram tensor.
            double_z: If True, predict both mean and log-variance (doubling latent channels).
            is_causal: If True, use causal convolutions suitable for streaming setups.
            dropout: Dropout probability used in residual and mid blocks.
            attn_type: Type of attention mechanism to use in attention blocks.
            resamp_with_conv: If True, perform resolution changes using strided convolutions.
            mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
        """
        super().__init__()

        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.double_z = double_z
        self.norm_type = norm_type
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # downsampling
        self.conv_in = make_conv2d(
            in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        self.non_linearity = torch.nn.SiLU()

        self.down, block_in = build_downsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Encode audio spectrogram into latent representations.
        Args:
            spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
        Returns:
            Encoded latent representation of shape (batch, channels, frames, mel_bins)
        """
        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)

        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage.block[block_idx](h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != self.num_resolutions - 1:
                h = stage.downsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        h = self.norm_out(h)
        h = self.non_linearity(h)
        return self.conv_out(h)

    def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
        """
        Normalize encoder latents using per-channel statistics.
        When the encoder is configured with ``double_z=True``, the final
        convolution produces twice the number of latent channels, typically
        interpreted as two concatenated tensors along the channel dimension
        (e.g., mean and variance or other auxiliary parameters).
        This method intentionally uses only the first half of the channels
        (the "mean" component) as input to the patchifier and normalization
        logic. The remaining channels are left unchanged by this method and
        are expected to be consumed elsewhere in the VAE pipeline.
        If ``double_z=False``, the encoder output already contains only the
        mean latents and the chunking operation simply returns that tensor.
        """
        means = torch.chunk(latent_output, 2, dim=1)[0]
        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[1],
            frames=means.shape[2],
            mel_bins=means.shape[3],
        )
        latent_patched = self.patchifier.patchify(means)
        latent_normalized = self.per_channel_statistics.normalize(latent_patched)
        return self.patchifier.unpatchify(latent_normalized, latent_shape)


def encode_audio(
    audio: Audio,
    audio_encoder: AudioEncoder,
    audio_processor: AudioProcessor | None = None,
) -> torch.Tensor:
    """Encode audio waveform into latent representation.
    Args:
        audio: Audio container with waveform tensor of shape (batch, channels, samples) and sampling rate.
        audio_encoder: Audio encoder model
        audio_processor: Audio processor model (optional, if not provided, it will be created from the audio encoder)
    """
    dtype = next(audio_encoder.parameters()).dtype
    device = next(audio_encoder.parameters()).device

    if audio_processor is None:
        audio_processor = AudioProcessor(
            target_sample_rate=audio_encoder.sample_rate,
            mel_bins=audio_encoder.mel_bins,
            mel_hop_length=audio_encoder.mel_hop_length,
            n_fft=audio_encoder.n_fft,
        ).to(device=device)

    mel_spectrogram = audio_processor.waveform_to_mel(audio.to(device=device))

    latent = audio_encoder(mel_spectrogram.to(dtype=dtype))
    return latent


class AudioDecoder(torch.nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.
    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        resolution: int,
        z_channels: int,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        dropout: float = 0.0,
        mid_block_add_attention: bool = True,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = None,
    ) -> None:
        """
        Initialize the Decoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
            - resolution, z_channels
            - norm_type, causality_axis
        """
        super().__init__()

        # Internal behavioural defaults that are not driven by the checkpoint.
        resamp_with_conv = True
        attn_type = AttentionType.VANILLA

        # Per-channel statistics for denormalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        base_block_channels = ch * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = torch.nn.SiLU()
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )
        self.up, final_block_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
            initial_block_channels=base_block_channels,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.
        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        sample, target_shape = self._denormalize_latents(sample)

        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)

        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Adjust output shape to match target dimensions for variable-length audio.
        This function handles the common case where decoded audio spectrograms need to be
        resized to match a specific target shape.
        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
        Returns:
            Tensor adjusted to match target_shape exactly
        """
        # Current output shape: (batch, channels, time, frequency)
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
            # For audio: pad_left/right = frequency, pad_top/bottom = time
            padding = (
                0,
                max(freq_padding_needed, 0),  # frequency padding (left, right)
                0,
                max(time_padding_needed, 0),  # time padding (top, bottom)
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = self.non_linearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h


def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> Audio:
    """
    Decode an audio latent representation using the provided audio decoder and vocoder.
    Args:
        latent: Input audio latent tensor.
        audio_decoder: Model to decode the latent to waveform features.
        vocoder: Model to convert decoded features to audio waveform.
    Returns:
        Decoded audio with waveform and sampling rate.
    """
    decoded_audio = audio_decoder(latent)
    waveform = vocoder(decoded_audio).squeeze(0).float()
    return Audio(waveform=waveform, sampling_rate=vocoder.output_sampling_rate)
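`_adjust_output_shape` is a crop-then-pad contract: every axis that overshoots the target is trimmed first, then whatever still falls short is zero-padded, so the result always matches the target exactly. A standalone check of that contract (editor's sketch; the shapes are made up):

```python
import torch
import torch.nn.functional as F

decoded = torch.randn(1, 2, 37, 64)        # (batch, channels, time, freq)
tgt_c, tgt_t, tgt_f = 2, 40, 60            # hypothetical target shape

out = decoded[:, :tgt_c, : min(decoded.shape[2], tgt_t), : min(decoded.shape[3], tgt_f)]  # crop
pad_t, pad_f = tgt_t - out.shape[2], tgt_f - out.shape[3]
out = F.pad(out, (0, max(pad_f, 0), 0, max(pad_t, 0)))   # (freq left/right, time top/bottom)

assert out.shape == (1, tgt_c, tgt_t, tgt_f)             # time padded, freq cropped
```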
ltx2/ltx_core/model/audio_vae/causal_conv_2d.py
ADDED
@@ -0,0 +1,110 @@
import torch
import torch.nn.functional as F

from ltx_core.model.audio_vae.causality_axis import CausalityAxis


class CausalConv2d(torch.nn.Module):
    """
    A causal 2D convolution.
    This layer ensures that the output at time `t` only depends on inputs
    at time `t` and earlier. It achieves this by applying asymmetric padding
    to the time dimension (width) before the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int = 1,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()

        self.causality_axis = causality_axis

        # Ensure kernel_size and dilation are tuples
        kernel_size = torch.nn.modules.utils._pair(kernel_size)
        dilation = torch.nn.modules.utils._pair(dilation)

        # Calculate padding dimensions
        pad_h = (kernel_size[0] - 1) * dilation[0]
        pad_w = (kernel_size[1] - 1) * dilation[1]

        # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
        match self.causality_axis:
            case CausalityAxis.NONE:
                self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
                self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
            case CausalityAxis.HEIGHT:
                self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
            case _:
                raise ValueError(f"Invalid causality_axis: {causality_axis}")

        # The internal convolution layer uses no padding, as we handle it manually
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply causal padding before convolution
        x = F.pad(x, self.padding)
        return self.conv(x)


def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: tuple[int, int, int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Create a 2D convolution layer that can be either causal or non-causal.
    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolution kernel
        stride: Convolution stride
        padding: Padding (if None, will be calculated based on causal flag)
        dilation: Dilation rate
        groups: Number of groups for grouped convolution
        bias: Whether to use bias
        causality_axis: Dimension along which to apply causality.
    Returns:
        Either a regular Conv2d or CausalConv2d layer
    """
    if causality_axis is not None:
        # For causal convolution, padding is handled internally by CausalConv2d
        return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
    else:
        # For non-causal convolution, use symmetric padding if not specified
        if padding is None:
            padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)

        return torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
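A quick causality check (editor's sketch, not part of the vendored file): perturbing "future" positions along the causal axis must leave earlier outputs untouched.

```python
import torch
from ltx_core.model.audio_vae.causal_conv_2d import CausalConv2d
from ltx_core.model.audio_vae.causality_axis import CausalityAxis

conv = CausalConv2d(1, 1, kernel_size=3, causality_axis=CausalityAxis.HEIGHT).eval()

x = torch.randn(1, 1, 16, 8)
x_future = x.clone()
x_future[:, :, 10:, :] += 1.0     # change rows 10.. only

y, y_future = conv(x), conv(x_future)
assert torch.allclose(y[:, :, :10], y_future[:, :, :10])   # rows 0..9 unaffected
```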
ltx2/ltx_core/model/audio_vae/causality_axis.py
ADDED
@@ -0,0 +1,10 @@
from enum import Enum


class CausalityAxis(Enum):
    """Enum for specifying the causality axis in causal convolutions."""

    NONE = None
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"
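One detail worth noting (editor's observation): `NONE` carries the value `None`, so value-style lookup from a parsed config works for plain strings and for a literal null alike.

```python
from ltx_core.model.audio_vae.causality_axis import CausalityAxis

assert CausalityAxis("width") is CausalityAxis.WIDTH
assert CausalityAxis(None) is CausalityAxis.NONE   # e.g. a YAML/JSON null
```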
ltx2/ltx_core/model/audio_vae/downsample.py
ADDED
@@ -0,0 +1,110 @@
from typing import Set, Tuple

import torch

from ltx_core.model.audio_vae.attention import AttentionType, make_attn
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
from ltx_core.model.audio_vae.resnet import ResnetBlock
from ltx_core.model.common.normalization import NormType


class Downsample(torch.nn.Module):
    """
    A downsampling layer that can use either a strided convolution
    or average pooling. Supports standard and causal padding for the
    convolutional mode.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.with_conv:
            # Padding tuple is in the order: (left, right, top, bottom).
            match self.causality_axis:
                case CausalityAxis.NONE:
                    pad = (0, 1, 0, 1)
                case CausalityAxis.WIDTH:
                    pad = (2, 0, 0, 1)
                case CausalityAxis.HEIGHT:
                    pad = (0, 1, 2, 0)
                case CausalityAxis.WIDTH_COMPATIBILITY:
                    pad = (1, 0, 0, 1)
                case _:
                    raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            # This branch is only taken if with_conv=False, which implies causality_axis is NONE.
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        return x


def build_downsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the downsampling path with residual blocks, attention, and downsampling layers."""
    down_modules = torch.nn.ModuleList()
    curr_res = resolution
    in_ch_mult = (1, *tuple(ch_mult))
    block_in = ch

    for i_level in range(num_resolutions):
        block = torch.nn.ModuleList()
        attn = torch.nn.ModuleList()
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]

        for _ in range(num_res_blocks):
            block.append(
                ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            block_in = block_out
            if curr_res in attn_resolutions:
                attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))

        down = torch.nn.Module()
        down.block = block
        down.attn = attn
        if i_level != num_resolutions - 1:
            down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res = curr_res // 2
        down_modules.append(down)

    return down_modules, block_in
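Each pad tuple above adds just enough on each axis (one or two samples, kept on the "past" side of the causal axis) for the kernel-3, stride-2 convolution to halve both spatial dimensions. Shape check (editor's sketch):

```python
import torch
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
from ltx_core.model.audio_vae.downsample import Downsample

down = Downsample(4, with_conv=True, causality_axis=CausalityAxis.WIDTH)
x = torch.randn(1, 4, 16, 16)
assert down(x).shape == (1, 4, 8, 8)   # both spatial axes halved
```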
ltx2/ltx_core/model/audio_vae/model_configurator.py
ADDED
@@ -0,0 +1,200 @@
import torch

from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps
from ltx_core.model.audio_vae.attention import AttentionType
from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
from ltx_core.model.audio_vae.vocoder import MelSTFT, Vocoder, VocoderWithBWE
from ltx_core.model.common.normalization import NormType
from ltx_core.model.model_protocol import ModelConfigurator
from ltx_core.utils import check_config_value


def _vocoder_from_config(
    cfg: dict,
    apply_final_activation: bool = True,
    output_sampling_rate: int | None = None,
) -> Vocoder:
    """Instantiate a Vocoder from a flat config dict.
    Args:
        cfg: Vocoder config dict (keys match Vocoder constructor args).
        apply_final_activation: Whether to apply tanh/clamp at the output.
        output_sampling_rate: Explicit override for the output sample rate.
            When None, reads from cfg["output_sampling_rate"] (default 24000).
    """
    return Vocoder(
        resblock_kernel_sizes=cfg.get("resblock_kernel_sizes", [3, 7, 11]),
        upsample_rates=cfg.get("upsample_rates", [6, 5, 2, 2, 2]),
        upsample_kernel_sizes=cfg.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]),
        resblock_dilation_sizes=cfg.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]),
        upsample_initial_channel=cfg.get("upsample_initial_channel", 1024),
        resblock=cfg.get("resblock", "1"),
        output_sampling_rate=(
            output_sampling_rate if output_sampling_rate is not None else cfg.get("output_sampling_rate", 24000)
        ),
        activation=cfg.get("activation", "snake"),
        use_tanh_at_final=cfg.get("use_tanh_at_final", True),
        apply_final_activation=apply_final_activation,
        use_bias_at_final=cfg.get("use_bias_at_final", True),
    )


class VocoderConfigurator(ModelConfigurator[Vocoder]):
    """Configurator that auto-detects the checkpoint format.
    Returns a plain Vocoder for pre-ltx-2.3 checkpoints (flat config) or a
    VocoderWithBWE for ltx-2.3+ checkpoints (nested "vocoder" + "bwe" config).
    """

    @classmethod
    def from_config(cls: type[Vocoder], config: dict) -> Vocoder | VocoderWithBWE:
        cfg = config.get("vocoder", {})

        if "bwe" not in cfg:
            check_config_value(cfg, "resblock", "1")
            check_config_value(cfg, "stereo", True)
            return _vocoder_from_config(cfg)

        vocoder_cfg = cfg.get("vocoder", {})
        bwe_cfg = cfg["bwe"]

        check_config_value(vocoder_cfg, "resblock", "AMP1")
        check_config_value(vocoder_cfg, "stereo", True)
        check_config_value(vocoder_cfg, "activation", "snakebeta")
        check_config_value(bwe_cfg, "resblock", "AMP1")
        check_config_value(bwe_cfg, "stereo", True)
        check_config_value(bwe_cfg, "activation", "snakebeta")

        vocoder = _vocoder_from_config(
            vocoder_cfg,
            output_sampling_rate=bwe_cfg["input_sampling_rate"],
        )
        bwe_generator = _vocoder_from_config(
            bwe_cfg,
            apply_final_activation=False,
            output_sampling_rate=bwe_cfg["output_sampling_rate"],
        )
        mel_stft = MelSTFT(
            filter_length=bwe_cfg["n_fft"],
            hop_length=bwe_cfg["hop_length"],
            win_length=bwe_cfg["n_fft"],
            n_mel_channels=bwe_cfg["num_mels"],
        )
        return VocoderWithBWE(
            vocoder=vocoder,
            bwe_generator=bwe_generator,
            mel_stft=mel_stft,
            input_sampling_rate=bwe_cfg["input_sampling_rate"],
            output_sampling_rate=bwe_cfg["output_sampling_rate"],
            hop_length=bwe_cfg["hop_length"],
        )


def _strip_vocoder_prefix(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
    """Strip the leading 'vocoder.' prefix exactly once.
    Uses removeprefix instead of str.replace so that BWE keys like
    'vocoder.vocoder.conv_pre' become 'vocoder.conv_pre' (not 'conv_pre').
    Works identically for legacy keys like 'vocoder.conv_pre' → 'conv_pre'.
    """
    return [KeyValueOperationResult(key.removeprefix("vocoder."), value)]


VOCODER_COMFY_KEYS_FILTER = (
    SDOps("VOCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="vocoder.")
    .with_kv_operation(operation=_strip_vocoder_prefix, key_prefix="vocoder.")
)


class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]):
    @classmethod
    def from_config(cls: type[AudioDecoder], config: dict) -> AudioDecoder:
        audio_vae_cfg = config.get("audio_vae", {})
        model_cfg = audio_vae_cfg.get("model", {})
        model_params = model_cfg.get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        sample_rate = model_params.get("sampling_rate", 16000)
        mel_hop_length = stft_cfg.get("hop_length", 160)
        is_causal = stft_cfg.get("causal", True)
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioDecoder(
            ch=ddconfig.get("ch", 128),
            out_ch=ddconfig.get("out_ch", 2),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            dropout=ddconfig.get("dropout", 0.0),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            is_causal=is_causal,
            mel_bins=mel_bins,
        )


class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]):
    @classmethod
    def from_config(cls: type[AudioEncoder], config: dict) -> AudioEncoder:
        audio_vae_cfg = config.get("audio_vae", {})
        model_cfg = audio_vae_cfg.get("model", {})
        model_params = model_cfg.get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        sample_rate = model_params.get("sampling_rate", 16000)
        mel_hop_length = stft_cfg.get("hop_length", 160)
        n_fft = stft_cfg.get("filter_length", 1024)
        is_causal = stft_cfg.get("causal", True)
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioEncoder(
            ch=ddconfig.get("ch", 128),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            double_z=ddconfig.get("double_z", True),
            dropout=ddconfig.get("dropout", 0.0),
            resamp_with_conv=ddconfig.get("resamp_with_conv", True),
            in_channels=ddconfig.get("in_channels", 2),
            attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            n_fft=n_fft,
            is_causal=is_causal,
            mel_bins=mel_bins,
        )


AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.decoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.decoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)


AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.encoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.encoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)
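The removeprefix choice matters only for the nested BWE layout, where keys contain "vocoder." twice; str.replace would strip every occurrence. In plain Python (editor's illustration):

```python
assert "vocoder.conv_pre".removeprefix("vocoder.") == "conv_pre"                  # legacy layout
assert "vocoder.vocoder.conv_pre".removeprefix("vocoder.") == "vocoder.conv_pre"  # BWE layout
assert "vocoder.vocoder.conv_pre".replace("vocoder.", "") == "conv_pre"           # the bug removeprefix avoids
```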
ltx2/ltx_core/model/audio_vae/ops.py
ADDED
@@ -0,0 +1,73 @@
import torch
import torchaudio
from torch import nn

from ltx_core.types import Audio


class AudioProcessor(nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""

    def __init__(
        self,
        target_sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        super().__init__()
        self.target_sample_rate = target_sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=target_sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_audio(self, audio: Audio) -> Audio:
        """Resample audio to the processor's target sample rate if needed."""
        if audio.sampling_rate == self.target_sample_rate:
            return audio
        resampled = torchaudio.functional.resample(audio.waveform, audio.sampling_rate, self.target_sample_rate)
        resampled = resampled.to(device=audio.waveform.device, dtype=audio.waveform.dtype)
        return Audio(waveform=resampled, sampling_rate=self.target_sample_rate)

    def waveform_to_mel(
        self,
        audio: Audio,
    ) -> torch.Tensor:
        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
        waveform = self.resample_audio(audio).waveform

        mel = self.mel_transform(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))

        mel = mel.to(device=waveform.device, dtype=waveform.dtype)
        return mel.permute(0, 1, 3, 2).contiguous()


class PerChannelStatistics(nn.Module):
    """
    Per-channel statistics for normalizing and denormalizing the latent representation.
    These statistics are computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
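`normalize` and `un_normalize` are exact inverses once the buffers hold real statistics; the dashed buffer names are also why access goes through `get_buffer` rather than attribute lookup. Round-trip sketch with made-up statistics (editor's example):

```python
import torch
from ltx_core.model.audio_vae.ops import PerChannelStatistics

stats = PerChannelStatistics(latent_channels=4)
stats.get_buffer("mean-of-means").copy_(torch.tensor([0.1, -0.2, 0.0, 0.5]))
stats.get_buffer("std-of-means").copy_(torch.tensor([1.0, 2.0, 0.5, 1.5]))

x = torch.randn(3, 4)   # last dim = latent channels
assert torch.allclose(stats.un_normalize(stats.normalize(x)), x, atol=1e-6)
```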
ltx2/ltx_core/model/audio_vae/resnet.py
ADDED
@@ -0,0 +1,176 @@
from typing import Tuple

import torch

from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
from ltx_core.model.common.normalization import NormType, build_normalization_layer

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                    padding="same",
                ),
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                    padding="same",
                ),
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[2],
                    padding="same",
                ),
            ]
        )

        self.convs2 = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding="same",
                ),
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding="same",
                ),
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding="same",
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = conv1(xt)
            xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
            xt = conv2(xt)
            x = xt + x
        return x


class ResBlock2(torch.nn.Module):
    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                    padding="same",
                ),
                torch.nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                    padding="same",
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = conv(xt)
            x = xt + x
        return x


class ResnetBlock(torch.nn.Module):
    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h
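`ResBlock1` and `ResBlock2` keep the input shape fixed (stride 1, padding="same" everywhere); only `ResnetBlock` can change the channel count, via the 3x3 `conv_shortcut` or the 1x1 `nin_shortcut`. Shape check for the 1D blocks (editor's sketch):

```python
import torch
from ltx_core.model.audio_vae.resnet import ResBlock2

block = ResBlock2(channels=8, kernel_size=3, dilation=(1, 3))
x = torch.randn(2, 8, 100)        # (batch, channels, time)
assert block(x).shape == x.shape  # residual stack preserves the shape
```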
ltx2/ltx_core/model/audio_vae/upsample.py
ADDED
@@ -0,0 +1,106 @@
+from typing import Set, Tuple
+
+import torch
+
+from ltx_core.model.audio_vae.attention import AttentionType, make_attn
+from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
+from ltx_core.model.audio_vae.causality_axis import CausalityAxis
+from ltx_core.model.audio_vae.resnet import ResnetBlock
+from ltx_core.model.common.normalization import NormType
+
+
+class Upsample(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        with_conv: bool,
+        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
+    ) -> None:
+        super().__init__()
+        self.with_conv = with_conv
+        self.causality_axis = causality_axis
+        if self.with_conv:
+            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
+        # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
+        # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
+        # So the output elements rely on the following windows:
+        # 0: [-,-,0]
+        # 1: [-,0,0]
+        # 2: [0,0,1]
+        # 3: [0,1,1]
+        # 4: [1,1,2]
+        # 5: [1,2,2]
+        # Notice that the first and second elements in the output rely only on the first element in the input,
+        # while all other elements rely on two elements in the input.
+        # So we can drop the first element to undo the padding (rather than the last element).
+        # This is a no-op for non-causal convolutions.
+        match self.causality_axis:
+            case CausalityAxis.NONE:
+                pass  # x remains unchanged
+            case CausalityAxis.HEIGHT:
+                x = x[:, :, 1:, :]
+            case CausalityAxis.WIDTH:
+                x = x[:, :, :, 1:]
+            case CausalityAxis.WIDTH_COMPATIBILITY:
+                pass  # x remains unchanged
+            case _:
+                raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
+
+        return x
+
+
+def build_upsampling_path(  # noqa: PLR0913
+    *,
+    ch: int,
+    ch_mult: Tuple[int, ...],
+    num_resolutions: int,
+    num_res_blocks: int,
+    resolution: int,
+    temb_channels: int,
+    dropout: float,
+    norm_type: NormType,
+    causality_axis: CausalityAxis,
+    attn_type: AttentionType,
+    attn_resolutions: Set[int],
+    resamp_with_conv: bool,
+    initial_block_channels: int,
+) -> tuple[torch.nn.ModuleList, int]:
+    """Build the upsampling path with residual blocks, attention, and upsampling layers."""
+    up_modules = torch.nn.ModuleList()
+    block_in = initial_block_channels
+    curr_res = resolution // (2 ** (num_resolutions - 1))
+
+    for level in reversed(range(num_resolutions)):
+        stage = torch.nn.Module()
+        stage.block = torch.nn.ModuleList()
+        stage.attn = torch.nn.ModuleList()
+        block_out = ch * ch_mult[level]
+
+        for _ in range(num_res_blocks + 1):
+            stage.block.append(
+                ResnetBlock(
+                    in_channels=block_in,
+                    out_channels=block_out,
+                    temb_channels=temb_channels,
+                    dropout=dropout,
+                    norm_type=norm_type,
+                    causality_axis=causality_axis,
+                )
+            )
+            block_in = block_out
+            if curr_res in attn_resolutions:
+                stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))
+
+        if level != 0:
+            stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
+            curr_res *= 2
+
+        up_modules.insert(0, stage)
+
+    return up_modules, block_in
ltx2/ltx_core/model/audio_vae/vocoder.py
ADDED
@@ -0,0 +1,594 @@
+import math
+from typing import List
+
+import einops
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1
+
+
+def get_padding(kernel_size: int, dilation: int = 1) -> int:
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+# ---------------------------------------------------------------------------
+# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
+# Adopted from https://github.com/NVIDIA/BigVGAN
+# ---------------------------------------------------------------------------
+
+
+def _sinc(x: torch.Tensor) -> torch.Tensor:
+    return torch.where(
+        x == 0,
+        torch.tensor(1.0, device=x.device, dtype=x.dtype),
+        torch.sin(math.pi * x) / math.pi / x,
+    )
+
+
+def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
+    even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+    delta_f = 4 * half_width
+    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if amplitude > 50.0:
+        beta = 0.1102 * (amplitude - 8.7)
+    elif amplitude >= 21.0:
+        beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
+    else:
+        beta = 0.0
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+    time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
+        filter_ /= filter_.sum()
+    return filter_.view(1, 1, kernel_size)
+
+
+class LowPassFilter1d(nn.Module):
+    def __init__(
+        self,
+        cutoff: float = 0.5,
+        half_width: float = 0.6,
+        stride: int = 1,
+        padding: bool = True,
+        padding_mode: str = "replicate",
+        kernel_size: int = 12,
+    ) -> None:
+        super().__init__()
+        if cutoff < -0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        _, n_channels, _ = x.shape
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+        return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)
+
+
+class UpSample1d(nn.Module):
+    def __init__(
+        self,
+        ratio: int = 2,
+        kernel_size: int | None = None,
+        persistent: bool = True,
+        window_type: str = "kaiser",
+    ) -> None:
+        super().__init__()
+        self.ratio = ratio
+        self.stride = ratio
+
+        if window_type == "hann":
+            # Hann-windowed sinc filter equivalent to torchaudio.functional.resample
+            rolloff = 0.99
+            lowpass_filter_width = 6
+            width = math.ceil(lowpass_filter_width / rolloff)
+            self.kernel_size = 2 * width * ratio + 1
+            self.pad = width
+            self.pad_left = 2 * width * ratio
+            self.pad_right = self.kernel_size - ratio
+            time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
+            time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
+            window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
+            sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
+        else:
+            # Kaiser-windowed sinc filter (BigVGAN default).
+            self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+            self.pad = self.kernel_size // ratio - 1
+            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+            self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+            sinc_filter = kaiser_sinc_filter1d(
+                cutoff=0.5 / ratio,
+                half_width=0.6 / ratio,
+                kernel_size=self.kernel_size,
+            )
+
+        self.register_buffer("filter", sinc_filter, persistent=persistent)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        _, n_channels, _ = x.shape
+        x = F.pad(x, (self.pad, self.pad), mode="replicate")
+        filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
+        x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
+        return x[..., self.pad_left : -self.pad_right]
+
+
+class DownSample1d(nn.Module):
+    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.lowpass = LowPassFilter1d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.lowpass(x)
+
+
+class Activation1d(nn.Module):
+    def __init__(
+        self,
+        activation: nn.Module,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ) -> None:
+        super().__init__()
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.upsample(x)
+        x = self.act(x)
+        return self.downsample(x)
+
+
+class Snake(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        alpha: float = 1.0,
+        alpha_trainable: bool = True,
+        alpha_logscale: bool = True,
+    ) -> None:
+        super().__init__()
+        self.alpha_logscale = alpha_logscale
+        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
+        self.alpha.requires_grad = alpha_trainable
+        self.eps = 1e-9
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)
+
+
+class SnakeBeta(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        alpha: float = 1.0,
+        alpha_trainable: bool = True,
+        alpha_logscale: bool = True,
+    ) -> None:
+        super().__init__()
+        self.alpha_logscale = alpha_logscale
+        self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
+        self.alpha.requires_grad = alpha_trainable
+        self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
+        self.beta.requires_grad = alpha_trainable
+        self.eps = 1e-9
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)
+
+
+class AMPBlock1(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int = 3,
+        dilation: tuple[int, int, int] = (1, 3, 5),
+        activation: str = "snake",
+    ) -> None:
+        super().__init__()
+        act_cls = SnakeBeta if activation == "snakebeta" else Snake
+        self.convs1 = nn.ModuleList(
+            [
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[0],
+                    padding=get_padding(kernel_size, dilation[0]),
+                ),
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[1],
+                    padding=get_padding(kernel_size, dilation[1]),
+                ),
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=dilation[2],
+                    padding=get_padding(kernel_size, dilation[2]),
+                ),
+            ]
+        )
+
+        self.convs2 = nn.ModuleList(
+            [
+                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+                nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
+            ]
+        )
+
+        self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
+        self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
+            xt = a1(x)
+            xt = c1(xt)
+            xt = a2(xt)
+            xt = c2(xt)
+            x = x + xt
+        return x
+
+
+class Vocoder(torch.nn.Module):
+    """
+    Vocoder model for synthesizing audio from Mel spectrograms.
+    Args:
+        resblock_kernel_sizes: List of kernel sizes for the residual blocks.
+            This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
+        upsample_rates: List of upsampling rates.
+            This value is read from the checkpoint at `config.vocoder.upsample_rates`.
+        upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
+            This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
+        resblock_dilation_sizes: List of dilation sizes for the residual blocks.
+            This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
+        upsample_initial_channel: Initial number of channels for the upsampling layers.
+            This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
+        resblock: Type of residual block to use ("1", "2", or "AMP1").
+            This value is read from the checkpoint at `config.vocoder.resblock`.
+        output_sampling_rate: Waveform sample rate.
+            This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
+        activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
+        use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
+        apply_final_activation: Whether to apply the final tanh/clamp activation.
+        use_bias_at_final: Whether to use bias in the final conv layer.
+    """
+
+    def __init__(  # noqa: PLR0913
+        self,
+        resblock_kernel_sizes: List[int] | None = None,
+        upsample_rates: List[int] | None = None,
+        upsample_kernel_sizes: List[int] | None = None,
+        resblock_dilation_sizes: List[List[int]] | None = None,
+        upsample_initial_channel: int = 1024,
+        resblock: str = "1",
+        output_sampling_rate: int = 24000,
+        activation: str = "snake",
+        use_tanh_at_final: bool = True,
+        apply_final_activation: bool = True,
+        use_bias_at_final: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # Mutable default values are not supported as default arguments.
+        if resblock_kernel_sizes is None:
+            resblock_kernel_sizes = [3, 7, 11]
+        if upsample_rates is None:
+            upsample_rates = [6, 5, 2, 2, 2]
+        if upsample_kernel_sizes is None:
+            upsample_kernel_sizes = [16, 15, 8, 4, 4]
+        if resblock_dilation_sizes is None:
+            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+
+        self.output_sampling_rate = output_sampling_rate
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.use_tanh_at_final = use_tanh_at_final
+        self.apply_final_activation = apply_final_activation
+        self.is_amp = resblock == "AMP1"
+
+        # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
+        # bins each), 2 output channels.
+        self.conv_pre = nn.Conv1d(
+            in_channels=128,
+            out_channels=upsample_initial_channel,
+            kernel_size=7,
+            stride=1,
+            padding=3,
+        )
+        resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
+
+        self.ups = nn.ModuleList(
+            nn.ConvTranspose1d(
+                upsample_initial_channel // (2**i),
+                upsample_initial_channel // (2 ** (i + 1)),
+                kernel_size,
+                stride,
+                padding=(kernel_size - stride) // 2,
+            )
+            for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
+        )
+
+        final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
+        self.resblocks = nn.ModuleList()
+
+        for i in range(len(upsample_rates)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
+                if self.is_amp:
+                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
+                else:
+                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
+
+        if self.is_amp:
+            self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
+        else:
+            self.act_post = nn.LeakyReLU()
+
+        # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
+        self.conv_post = nn.Conv1d(
+            in_channels=final_channels,
+            out_channels=2,
+            kernel_size=7,
+            stride=1,
+            padding=3,
+            bias=use_bias_at_final,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the vocoder.
+        Args:
+            x: Input Mel spectrogram tensor. Can be either:
+                - 3D: (batch_size, time, mel_bins) for mono
+                - 4D: (batch_size, 2, time, mel_bins) for stereo
+        Returns:
+            Audio waveform tensor of shape (batch_size, out_channels, audio_length)
+        """
+        x = x.transpose(2, 3)  # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time)
+
+        if x.dim() == 4:  # stereo
+            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
+            x = einops.rearrange(x, "b s c t -> b (s c) t")
+
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            if not self.is_amp:
+                x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            start = i * self.num_kernels
+            end = start + self.num_kernels
+
+            # Evaluate all resblocks with the same input tensor so they can run
+            # independently (and thus in parallel on accelerator hardware) before
+            # aggregating their outputs via mean.
+            block_outputs = torch.stack(
+                [self.resblocks[idx](x) for idx in range(start, end)],
+                dim=0,
+            )
+            x = block_outputs.mean(dim=0)
+
+        x = self.act_post(x)
+        x = self.conv_post(x)
+
+        if self.apply_final_activation:
+            x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
+
+        return x
+
+
+class _STFTFn(nn.Module):
+    """Implements STFT as a convolution with precomputed DFT x Hann-window bases.
+    The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
+    Hann window are stored as buffers and loaded from the checkpoint. Using the exact
+    bfloat16 bases from training ensures the mel values fed to the BWE generator are
+    bit-identical to what it was trained on.
+    """
+
+    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
+        super().__init__()
+        self.hop_length = hop_length
+        self.win_length = win_length
+        n_freqs = filter_length // 2 + 1
+        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
+        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
+
+    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute magnitude and phase spectrogram from a batch of waveforms.
+        Applies causal (left-only) padding of win_length - hop_length samples so that
+        each output frame depends only on past and present input — no lookahead.
+        Args:
+            y: Waveform tensor of shape (B, T).
+        Returns:
+            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
+            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
+        """
+        if y.dim() == 2:
+            y = y.unsqueeze(1)  # (B, 1, T)
+        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
+        y = F.pad(y, (left_pad, 0))
+        spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
+        n_freqs = spec.shape[1] // 2
+        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
+        magnitude = torch.sqrt(real**2 + imag**2)
+        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
+        return magnitude, phase
+
+
+class MelSTFT(nn.Module):
+    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
+    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
+    waveform and projecting the linear magnitude spectrum onto the mel filterbank.
+    The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
+    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
+    """
+
+    def __init__(
+        self,
+        filter_length: int,
+        hop_length: int,
+        win_length: int,
+        n_mel_channels: int,
+    ) -> None:
+        super().__init__()
+        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
+
+        # Initialized to zeros; load_state_dict overwrites with the checkpoint's
+        # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
+        n_freqs = filter_length // 2 + 1
+        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
+
+    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute log-mel spectrogram and auxiliary spectral quantities.
+        Args:
+            y: Waveform tensor of shape (B, T).
+        Returns:
+            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
+            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
+            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
+            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
+        """
+        magnitude, phase = self.stft_fn(y)
+        energy = torch.norm(magnitude, dim=1)
+        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
+        log_mel = torch.log(torch.clamp(mel, min=1e-5))
+        return log_mel, magnitude, phase, energy
+
+
+class VocoderWithBWE(nn.Module):
+    """Vocoder with bandwidth extension (BWE) upsampling.
+    Chains a mel-to-wav vocoder with a BWE module that upsamples the output
+    to a higher sample rate. The BWE computes a mel spectrogram from the
+    vocoder output, runs it through a second generator to predict a residual,
+    and adds it to a sinc-resampled skip connection.
+    The forward pass runs in fp32 via autocast to avoid bfloat16 accumulation
+    errors that degrade spectral metrics by 40-90%.
+    """
+
+    def __init__(
+        self,
+        vocoder: Vocoder,
+        bwe_generator: Vocoder,
+        mel_stft: MelSTFT,
+        input_sampling_rate: int,
+        output_sampling_rate: int,
+        hop_length: int,
+    ) -> None:
+        super().__init__()
+        self.vocoder = vocoder
+        self.bwe_generator = bwe_generator
+        self.mel_stft = mel_stft
+        self.input_sampling_rate = input_sampling_rate
+        self.output_sampling_rate = output_sampling_rate
+        self.hop_length = hop_length
+        # Compute the resampler on CPU so the sinc filter is materialized even when
+        # the model is constructed on meta device (SingleGPUModelBuilder pattern).
+        # The filter is not stored in the checkpoint (persistent=False).
+        with torch.device("cpu"):
+            self.resampler = UpSample1d(
+                ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
+            )
+
+    @property
+    def conv_pre(self) -> nn.Conv1d:
+        return self.vocoder.conv_pre
+
+    @property
+    def conv_post(self) -> nn.Conv1d:
+        return self.vocoder.conv_post
+
+    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
+        """Compute log-mel spectrogram from waveform using causal STFT bases.
+        Args:
+            audio: Waveform tensor of shape (B, C, T).
+        Returns:
+            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
+        """
+        batch, n_channels, _ = audio.shape
+        flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)
+        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
+        return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)
+
+    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
+        """Run the full vocoder + BWE forward pass.
+        Runs in float32 regardless of weight or input dtype. bfloat16 arithmetic
+        causes 40-90% spectral metric degradation due to accumulation errors
+        compounding through 108 sequential convolutions in the BigVGAN v2 architecture.
+        Args:
+            mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
+                or (B, T, mel_bins) for mono. Same format as Vocoder.forward.
+        Returns:
+            Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
+        """
+        input_dtype = mel_spec.dtype
+        # Run the entire forward pass in fp32. bfloat16 accumulation errors
+        # compound through 108 sequential convolutions and degrade spectral
+        # metrics (mel_l1, MRSTFT) by 40-90% while perceptual quality (CDPAM)
+        # is unaffected. fp32 eliminates this degradation.
+        # We use autocast(dtype=float32) rather than self.float() because it
+        # upcasts bf16 weights per-op at kernel level, avoiding the temporary
+        # memory spike of self.float() / self.to(original_dtype).
+        # Benchmarked on H100 (128.5M-param model):
+        #   autocast fp32: +70 MB peak VRAM, 123 ms (vs 482 MB / 95 ms for bf16)
+        #   model.float(): +324 MB peak VRAM, 149 ms
+        # Tested: both approaches produce bit-identical output.
+
+        with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
+            x = self.vocoder(mel_spec.float())
+            _, _, length_low_rate = x.shape
+            output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
+
+            # Pad to multiple of hop_length for exact mel frame count
+            remainder = length_low_rate % self.hop_length
+            if remainder != 0:
+                x = F.pad(x, (0, self.hop_length - remainder))
+
+            # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
+            mel = self._compute_mel(x)
+
+            # Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
+            mel_for_bwe = mel.transpose(2, 3)  # (B, C, T_frames, mel_bins)
+            residual = self.bwe_generator(mel_for_bwe)
+            skip = self.resampler(x)
+            assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
+
+            return torch.clamp(residual + skip, -1, 1)[..., :output_length].to(input_dtype)
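A hedged end-to-end sketch of the mel-to-waveform path above, assuming the vendored module is importable; the 240x figure follows from the default upsample_rates (6 * 5 * 2 * 2 * 2):

import torch
from ltx_core.model.audio_vae.vocoder import Vocoder

voc = Vocoder()                  # defaults: ResBlock1, 24 kHz output, 240x total upsampling
mel = torch.randn(1, 2, 32, 64)  # (batch, stereo, mel frames, 64 mel bins) -> 128 conv_pre channels
wav = voc(mel)                   # tanh output in [-1, 1]
print(wav.shape)                 # torch.Size([1, 2, 7680]) — 32 frames * 240 samples/frame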
ltx2/ltx_core/model/common/__init__.py
ADDED
@@ -0,0 +1,9 @@
+"""Common model utilities."""
+
+from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer
+
+__all__ = [
+    "NormType",
+    "PixelNorm",
+    "build_normalization_layer",
+]