Add files using upload-large-folder tool
- packages/ltx-core/pyproject.toml +55 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/attention.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/feed_forward.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/gelu_approx.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/modality.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model_configurator.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/rope.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/text_projection.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/timestep_embedding.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py +202 -0
- packages/ltx-trainer/configs/accelerate/ddp.yaml +16 -0
- packages/ltx-trainer/configs/accelerate/ddp_compile.yaml +21 -0
- packages/ltx-trainer/configs/accelerate/fsdp.yaml +29 -0
- packages/ltx-trainer/configs/accelerate/fsdp_compile.yaml +34 -0
- packages/ltx-trainer/configs/ltx2_av_lora.yaml +313 -0
- packages/ltx-trainer/configs/ltx2_av_lora_low_vram.yaml +325 -0
- packages/ltx-trainer/configs/ltx2_v2v_ic_lora.yaml +329 -0
- packages/ltx-trainer/docs/configuration-reference.md +372 -0
- packages/ltx-trainer/docs/custom-training-strategies.md +510 -0
- packages/ltx-trainer/docs/dataset-preparation.md +342 -0
- packages/ltx-trainer/docs/quick-start.md +130 -0
- packages/ltx-trainer/docs/training-guide.md +203 -0
- packages/ltx-trainer/docs/training-modes.md +277 -0
- packages/ltx-trainer/docs/troubleshooting.md +300 -0
- packages/ltx-trainer/docs/utility-scripts.md +274 -0
- packages/ltx-trainer/scripts/caption_videos.py +486 -0
- packages/ltx-trainer/scripts/compute_reference.py +288 -0
- packages/ltx-trainer/scripts/decode_latents.py +369 -0
- packages/ltx-trainer/scripts/process_captions.py +435 -0
- packages/ltx-trainer/scripts/process_dataset.py +317 -0
- packages/ltx-trainer/scripts/process_videos.py +1039 -0
- packages/ltx-trainer/scripts/split_scenes.py +417 -0
- packages/ltx-trainer/scripts/train.py +64 -0
- packages/ltx-trainer/src/ltx_trainer/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-trainer/src/ltx_trainer/__pycache__/model_loader.cpython-312.pyc +0 -0
- packages/ltx-trainer/src/ltx_trainer/captioning.py +401 -0
- packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py +85 -0
- packages/ltx-trainer/src/ltx_trainer/gpu_utils.py +90 -0
- packages/ltx-trainer/src/ltx_trainer/progress.py +236 -0
- packages/ltx-trainer/src/ltx_trainer/quantization.py +195 -0
- packages/ltx-trainer/src/ltx_trainer/trainer.py +1000 -0
- packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py +58 -0
- packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py +262 -0
- packages/ltx-trainer/src/ltx_trainer/training_strategies/text_to_video.py +291 -0
- packages/ltx-trainer/src/ltx_trainer/training_strategies/video_to_video.py +303 -0
- packages/ltx-trainer/src/ltx_trainer/utils.py +88 -0
- packages/ltx-trainer/templates/model_card.md +59 -0
packages/ltx-core/pyproject.toml
ADDED
@@ -0,0 +1,55 @@
[project]
name = "ltx-core"
version = "1.0.0"
description = "Core implementation of Lightricks' LTX-2 model"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "torch~=2.7",
    "torchaudio",
    "einops",
    "numpy",
    "transformers>=4.52",
    "safetensors",
    "accelerate",
    "scipy>=1.14",
]

[project.optional-dependencies]
xformers = ["xformers"]
fp8-trtllm = [
    "tensorrt-llm==1.0.0",
    "onnx>=1.16.0,<1.20.0",
    "openmpi",
]

[tool.uv]
conflicts = [
    [
        { extra = "xformers" },
        { extra = "fp8-trtllm" },
    ],
]

[tool.uv.sources]
xformers = { index = "pytorch" }
tensorrt-llm = { index = "nvidia" }

[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu129"
explicit = true

[[tool.uv.index]]
name = "nvidia"
url = "https://pypi.nvidia.com/"
explicit = true

[build-system]
requires = ["uv_build>=0.9.8,<0.10.0"]
build-backend = "uv_build"

[dependency-groups]
dev = [
    "scikit-image>=0.25.2",
]
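The [tool.uv] conflicts table above declares the xformers and fp8-trtllm extras as mutually exclusive, so uv will refuse to resolve an environment that enables both at once. A minimal sketch (not part of the commit; assumes Python 3.11+ for the standard-library tomllib) that reads the declaration back:

import tomllib

# Parse the package metadata and list each declared conflict set.
with open("packages/ltx-core/pyproject.toml", "rb") as f:
    pyproject = tomllib.load(f)

for conflict_set in pyproject["tool"]["uv"]["conflicts"]:
    extras = [entry["extra"] for entry in conflict_set]
    print("mutually exclusive extras:", extras)  # ['xformers', 'fp8-trtllm']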
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (624 Bytes)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/attention.cpython-312.pyc
ADDED
Binary file (13.3 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/feed_forward.cpython-312.pyc
ADDED
Binary file (1.52 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/gelu_approx.cpython-312.pyc
ADDED
Binary file (1.28 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/modality.cpython-312.pyc
ADDED
Binary file (2.29 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model.cpython-312.pyc
ADDED
Binary file (20.6 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model_configurator.cpython-312.pyc
ADDED
Binary file (9.1 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/rope.cpython-312.pyc
ADDED
Binary file (10.6 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/text_projection.cpython-312.pyc
ADDED
Binary file (2.75 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/timestep_embedding.cpython-312.pyc
ADDED
Binary file (7.2 kB)
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer.cpython-312.pyc
ADDED
Binary file (18 kB)
packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py
ADDED
@@ -0,0 +1,202 @@
import functools
from pathlib import Path

import torch
from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor

from ltx_core.loader.module_ops import ModuleOps
from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
from ltx_core.utils import find_matching_file


class GemmaTextEncoder(torch.nn.Module):
    """Pure Gemma text encoder — runs the LLM and returns raw hidden states.

    Prompt enhancement (generate) is also supported since the full
    Gemma3ForConditionalGeneration model (including lm_head) is loaded.
    """

    def __init__(
        self,
        model: Gemma3ForConditionalGeneration | None = None,
        tokenizer: LTXVGemmaTokenizer | None = None,
        processor: Gemma3Processor | None = None,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.processor = processor
        self._dtype = dtype

    def encode(
        self,
        text: str,
        padding_side: str = "left",  # noqa: ARG002
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        """Run Gemma LLM and return raw hidden states + attention mask.

        Calls the inner model (self.model.model) to skip lm_head logits computation (~500 MiB saving).

        Returns:
            (hidden_states, attention_mask) where hidden_states is a tuple of per-layer tensors.
        """
        token_pairs = self.tokenizer.tokenize_with_weights(text)["gemma"]
        input_ids = torch.tensor([[t[0] for t in token_pairs]], device=self.model.device)
        attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=self.model.device)
        outputs = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        del outputs
        return hidden_states, attention_mask

    # --- Prompt enhancement methods ---

    def _enhance(
        self,
        messages: list[dict[str, str]],
        image: torch.Tensor | None = None,
        max_new_tokens: int = 512,
        seed: int = 10,
    ) -> str:
        text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        model_inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
        ).to(self.model.device)
        pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0
        model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id)

        with torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]):
            torch.manual_seed(seed)
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
            )
        generated_ids = outputs[0][len(model_inputs.input_ids[0]) :]
        enhanced_prompt = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

        return enhanced_prompt

    def enhance_t2v(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        system_prompt: str | None = None,
        seed: int = 10,
    ) -> str:
        """Enhance a text prompt for T2V generation."""
        system_prompt = system_prompt or self.default_gemma_t2v_system_prompt

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user prompt: {prompt}"},
        ]

        return self._enhance(messages, max_new_tokens=max_new_tokens, seed=seed)

    def enhance_i2v(
        self,
        prompt: str,
        image: torch.Tensor,
        max_new_tokens: int = 512,
        system_prompt: str | None = None,
        seed: int = 10,
    ) -> str:
        """Enhance a text prompt for I2V generation using a reference image."""
        system_prompt = system_prompt or self.default_gemma_i2v_system_prompt
        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
                ],
            },
        ]
        return self._enhance(messages, image=image, max_new_tokens=max_new_tokens, seed=seed)

    @functools.cached_property
    def default_gemma_i2v_system_prompt(self) -> str:
        return _load_system_prompt("gemma_i2v_system_prompt.txt")

    @functools.cached_property
    def default_gemma_t2v_system_prompt(self) -> str:
        return _load_system_prompt("gemma_t2v_system_prompt.txt")


# --- Standalone utility functions ---


@functools.lru_cache(maxsize=2)
def _load_system_prompt(prompt_name: str) -> str:
    with open(Path(__file__).parent / "prompts" / f"{prompt_name}", "r") as f:
        return f.read()


def _cat_with_padding(
    tensor: torch.Tensor,
    padding_length: int,
    value: int | float,
) -> torch.Tensor:
    """Concatenate a tensor with a padding tensor of the given value."""
    return torch.cat(
        [
            tensor,
            torch.full(
                (1, padding_length),
                value,
                dtype=tensor.dtype,
                device=tensor.device,
            ),
        ],
        dim=1,
    )


def _pad_inputs_for_attention_alignment(
    model_inputs: dict[str, torch.Tensor],
    pad_token_id: int = 0,
    alignment: int = 8,
) -> dict[str, torch.Tensor]:
    """Pad sequence length to multiple of alignment for Flash Attention compatibility."""
    seq_len = model_inputs.input_ids.shape[1]
    padded_len = ((seq_len + alignment - 1) // alignment) * alignment
    padding_length = padded_len - seq_len

    if padding_length > 0:
        model_inputs["input_ids"] = _cat_with_padding(model_inputs.input_ids, padding_length, pad_token_id)
        model_inputs["attention_mask"] = _cat_with_padding(model_inputs.attention_mask, padding_length, 0)
        if "token_type_ids" in model_inputs and model_inputs["token_type_ids"] is not None:
            model_inputs["token_type_ids"] = _cat_with_padding(model_inputs["token_type_ids"], padding_length, 0)

    return model_inputs


def module_ops_from_gemma_root(gemma_root: str) -> tuple[ModuleOps, ...]:
    tokenizer_root = str(find_matching_file(gemma_root, "tokenizer.model").parent)
    processor_root = str(find_matching_file(gemma_root, "preprocessor_config.json").parent)

    def load_tokenizer(module: GemmaTextEncoder) -> GemmaTextEncoder:
        module.tokenizer = LTXVGemmaTokenizer(tokenizer_root, 1024)
        return module

    def load_processor(module: GemmaTextEncoder) -> GemmaTextEncoder:
        image_processor = AutoImageProcessor.from_pretrained(processor_root, local_files_only=True)
        if not module.tokenizer:
            raise ValueError("Tokenizer model operation must be performed before processor model operation")
        module.processor = Gemma3Processor(image_processor=image_processor, tokenizer=module.tokenizer.tokenizer)
        return module

    tokenizer_load_ops = ModuleOps(
        "TokenizerLoad",
        matcher=lambda module: isinstance(module, GemmaTextEncoder) and module.tokenizer is None,
        mutator=load_tokenizer,
    )
    processor_load_ops = ModuleOps(
        "ProcessorLoad",
        matcher=lambda module: isinstance(module, GemmaTextEncoder) and module.processor is None,
        mutator=load_processor,
    )
    return (tokenizer_load_ops, processor_load_ops)
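A hypothetical usage sketch (not part of the commit): the intended loading path goes through the ModuleOps returned by module_ops_from_gemma_root, but constructing the encoder directly exercises the same API. The checkpoint path and device placement below are illustrative assumptions.

import torch
from transformers import Gemma3ForConditionalGeneration
from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoder
from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer

gemma_root = "path/to/gemma-text-encoder"  # assumed local checkpoint directory
model = Gemma3ForConditionalGeneration.from_pretrained(gemma_root, torch_dtype=torch.bfloat16).to("cuda")
encoder = GemmaTextEncoder(model=model, tokenizer=LTXVGemmaTokenizer(gemma_root, 1024))

# encode() returns one hidden-state tensor per transformer layer, plus the mask.
hidden_states, attention_mask = encoder.encode("A chef plating a gourmet dish")
print(len(hidden_states), hidden_states[-1].shape, attention_mask.shape)

# For reference, the alignment helper above rounds sequence lengths up to a
# multiple of 8 before generate(), e.g. 13 tokens -> ((13 + 7) // 8) * 8 = 16.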
packages/ltx-trainer/configs/accelerate/ddp.yaml
ADDED
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
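These accelerate configs are consumed by `accelerate launch`. A hypothetical launch sketch (not part of the commit; it assumes train.py takes the trainer config path as its positional argument, which this commit does not confirm):

import subprocess

# Launch the trainer under DDP using the config above; paths are illustrative.
subprocess.run(
    [
        "accelerate", "launch",
        "--config_file", "packages/ltx-trainer/configs/accelerate/ddp.yaml",
        "packages/ltx-trainer/scripts/train.py",
        "packages/ltx-trainer/configs/ltx2_av_lora.yaml",  # assumed argument
    ],
    check=True,
)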
packages/ltx-trainer/configs/accelerate/ddp_compile.yaml
ADDED
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_mode: default
  dynamo_use_fullgraph: false
  dynamo_use_dynamic: true
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
packages/ltx-trainer/configs/accelerate/fsdp.yaml
ADDED
@@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
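A minimal sketch (not part of the commit) of what TRANSFORMER_BASED_WRAP plus fsdp_transformer_layer_cls_to_wrap amounts to in raw PyTorch: each BasicAVTransformerBlock instance becomes its own FSDP sharding unit. The import location of BasicAVTransformerBlock is an assumption.

import functools
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from ltx_core.model.transformer.model import BasicAVTransformerBlock  # assumed path

# Accelerate builds an equivalent policy from the YAML keys above and passes
# it to FSDP as FSDP(model, auto_wrap_policy=auto_wrap_policy, ...).
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={BasicAVTransformerBlock},
)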
packages/ltx-trainer/configs/accelerate/fsdp_compile.yaml
ADDED
@@ -0,0 +1,34 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_mode: default
  dynamo_use_fullgraph: false
  dynamo_use_dynamic: true
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
packages/ltx-trainer/configs/ltx2_av_lora.yaml
ADDED
@@ -0,0 +1,313 @@
# =============================================================================
# LTX-2 Audio-Video LoRA Training Configuration
# =============================================================================
#
# This configuration is for training LoRA adapters on the LTX-2 model for
# text-to-video generation. It supports both video-only and joint audio-video
# training modes.
#
# Use this configuration when you want to:
# - Fine-tune LTX-2 on your own video dataset
# - Train with or without audio generation
# - Create custom video generation styles or audiovisual concepts
#
# Dataset structure for text-to-video training:
#   preprocessed_data_root/
#   ├── latents/        # Video latents (VAE-encoded videos)
#   ├── conditions/     # Text embeddings for each video
#   └── audio_latents/  # Audio latents (only if with_audio: true)
#
# =============================================================================

# -----------------------------------------------------------------------------
# Model Configuration
# -----------------------------------------------------------------------------
# Specifies the base model to fine-tune and the training mode.
model:
  # Path to the LTX-2 model checkpoint (.safetensors file)
  # This should be a local path to your downloaded model
  model_path: "path/to/ltx-2-model.safetensors"

  # Path to the text encoder model directory
  # For LTX-2, this is typically the Gemma-based text encoder
  text_encoder_path: "path/to/gemma-text-encoder"

  # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning
  # LoRA is recommended for most use cases (faster, less memory, prevents overfitting)
  training_mode: "lora"

  # Optional: Path to resume training from a checkpoint
  # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint)
  load_checkpoint: null

# -----------------------------------------------------------------------------
# LoRA Configuration
# -----------------------------------------------------------------------------
# Controls the Low-Rank Adaptation parameters for efficient fine-tuning.
lora:
  # Rank of the LoRA matrices (higher = more capacity but more parameters)
  # Typical values: 8, 16, 32, 64. Start with 32 for general fine-tuning.
  rank: 32

  # Alpha scaling factor (usually set equal to rank)
  # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0
  alpha: 32

  # Dropout probability for LoRA layers (0.0 = no dropout)
  # Can help with regularization if overfitting occurs
  dropout: 0.0

  # Which transformer modules to apply LoRA to
  # The LTX-2 transformer has separate attention and FFN blocks for video and audio:
  #
  # VIDEO MODULES:
  # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention)
  # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text)
  # - ff.net.0.proj, ff.net.2 (video feed-forward)
  #
  # AUDIO MODULES:
  # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention)
  # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text)
  # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward)
  #
  # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction):
  # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0
  #   (Q from video, K/V from audio - allows video to attend to audio features)
  # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0
  #   (Q from audio, K/V from video - allows audio to attend to video features)
  #
  # Using short patterns like "to_k" matches ALL attention modules (video, audio, and cross-modal).
  # For audio-video training, this is the recommended approach.
  target_modules:
    # Attention layers (matches both video and audio branches)
    - "to_k"
    - "to_q"
    - "to_v"
    - "to_out.0"
    # Uncomment below to also train feed-forward layers (can increase the LoRA's capacity):
    # - "ff.net.0.proj"
    # - "ff.net.2"
    # - "audio_ff.net.0.proj"
    # - "audio_ff.net.2"

# -----------------------------------------------------------------------------
# Training Strategy Configuration
# -----------------------------------------------------------------------------
# Defines the text-to-video training approach.
training_strategy:
  # Strategy name: "text_to_video" for standard text-to-video training
  name: "text_to_video"

  # Probability of conditioning on the first frame during training
  # Higher values train the model to perform better in image-to-video (I2V) mode,
  # where a clean first frame is provided and the model generates the rest of the video
  first_frame_conditioning_p: 0.5

  # Enable joint audio-video training
  # Set to true if your dataset includes audio and you want to train the audio branch
  with_audio: true

  # Directory name (within preprocessed_data_root) containing audio latents
  # Only used when with_audio is true
  audio_latents_dir: "audio_latents"

# -----------------------------------------------------------------------------
# Optimization Configuration
# -----------------------------------------------------------------------------
# Controls the training optimization parameters.
optimization:
  # Learning rate for the optimizer
  # Typical range for LoRA: 1e-5 to 1e-4
  learning_rate: 1e-4

  # Total number of training steps
  steps: 2000

  # Batch size per GPU
  # Reduce if running out of memory
  batch_size: 1

  # Number of gradient accumulation steps
  # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus
  gradient_accumulation_steps: 1

  # Maximum gradient norm for clipping (helps training stability)
  max_grad_norm: 1.0

  # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient)
  optimizer_type: "adamw"

  # Learning rate scheduler type
  # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial"
  scheduler_type: "linear"

  # Additional scheduler parameters (depends on scheduler_type)
  scheduler_params: {}

  # Enable gradient checkpointing to reduce memory usage
  # Recommended for training with limited GPU memory
  enable_gradient_checkpointing: true

# -----------------------------------------------------------------------------
# Acceleration Configuration
# -----------------------------------------------------------------------------
# Hardware acceleration and memory optimization settings.
acceleration:
  # Mixed precision training mode
  # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended)
  mixed_precision_mode: "bf16"

  # Model quantization for reduced memory usage
  # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto"
  quantization: null

  # Load text encoder in 8-bit precision to save memory
  # Useful when GPU memory is limited
  load_text_encoder_in_8bit: false

# -----------------------------------------------------------------------------
# Data Configuration
# -----------------------------------------------------------------------------
# Specifies the training data location and loading parameters.
data:
  # Root directory containing preprocessed training data
  # Should contain: latents/, conditions/, and optionally audio_latents/
  preprocessed_data_root: "/path/to/preprocessed/data"

  # Number of worker processes for data loading
  # Used to parallelize and speed up data loading
  num_dataloader_workers: 2

# -----------------------------------------------------------------------------
# Validation Configuration
# -----------------------------------------------------------------------------
# Controls validation video generation during training.
# NOTE: Validation sampling uses simplified inference pipelines and prioritizes speed over
# maximum quality. For production-quality inference, use `packages/ltx-pipelines`.
validation:
  # Text prompts for validation video generation
  # Provide prompts representative of your training data
  # LTX-2 prefers longer, detailed prompts that describe both visual content and audio
  prompts:
    - "A woman with long brown hair sits at a wooden desk in a cozy home office, typing on a laptop while occasionally glancing at notes beside her. Soft natural light streams through a large window, casting warm shadows across the room. She pauses to take a sip from a ceramic mug, then continues working with focused concentration. The audio captures the gentle clicking of keyboard keys, the soft rustle of papers, and ambient room tone with occasional distant bird chirps from outside."
    - "A chef in a white uniform stands in a professional kitchen, carefully plating a gourmet dish with precise movements. Steam rises from freshly cooked vegetables as he arranges them with tweezers. The stainless steel surfaces gleam under bright overhead lights, and various pots simmer on the stove behind him. The audio features the sizzling of pans, the clinking of utensils against plates, and the ambient hum of kitchen ventilation."

  # Negative prompt to avoid unwanted artifacts
  negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"

  # Optional: First frame images for image-to-video validation
  # If provided, must have one image per prompt
  images: null

  # Output video dimensions [width, height, frames]
  # Width and height must be divisible by 32
  # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...)
  video_dims: [576, 576, 89]

  # Frame rate for generated videos
  frame_rate: 25.0

  # Random seed for reproducible validation outputs
  seed: 42

  # Number of denoising steps for validation inference
  # Higher values = better quality but slower generation
  inference_steps: 30

  # Generate validation videos every N training steps
  # Set to null to disable validation during training
  interval: 100

  # Number of videos to generate per prompt
  videos_per_prompt: 1

  # Classifier-free guidance scale
  # Higher values = stronger adherence to prompt but may introduce artifacts
  guidance_scale: 4.0

  # STG (Spatio-Temporal Guidance) parameters for improved video quality
  # STG is combined with CFG for better temporal coherence
  stg_scale: 1.0      # Recommended: 1.0 (0.0 disables STG)
  stg_blocks: [29]    # Recommended: single block 29
  stg_mode: "stg_av"  # "stg_av" perturbs both audio and video, "stg_v" video only

  # Whether to generate audio in validation samples
  # Independent of training_strategy.with_audio - you can generate audio
  # in validation even when not training the audio branch
  generate_audio: true

  # Skip validation at the beginning of training (step 0)
  skip_initial_validation: false

# -----------------------------------------------------------------------------
# Checkpoint Configuration
# -----------------------------------------------------------------------------
# Controls model checkpoint saving during training.
checkpoints:
  # Save a checkpoint every N steps
  # Set to null to disable intermediate checkpoints
  interval: 250

  # Number of most recent checkpoints to keep
  # Set to -1 to keep all checkpoints
  keep_last_n: -1

  # Precision to use when saving checkpoint weights
  # Options: "bfloat16" (default, smaller files) or "float32" (full precision)
  precision: "bfloat16"

# -----------------------------------------------------------------------------
# Flow Matching Configuration
# -----------------------------------------------------------------------------
# Parameters for the flow matching training objective.
flow_matching:
  # Timestep sampling mode
  # "shifted_logit_normal" is recommended for LTX-2 models
  timestep_sampling_mode: "shifted_logit_normal"

  # Additional parameters for timestep sampling
  timestep_sampling_params: {}

# -----------------------------------------------------------------------------
# Hugging Face Hub Configuration
# -----------------------------------------------------------------------------
# Settings for uploading trained models to the Hugging Face Hub.
hub:
  # Whether to push the trained model to the Hub
  push_to_hub: false

  # Repository ID on Hugging Face Hub (e.g., "username/my-lora-model")
  # Required if push_to_hub is true
  hub_model_id: null

# -----------------------------------------------------------------------------
# Weights & Biases Configuration
# -----------------------------------------------------------------------------
# Settings for experiment tracking with W&B.
wandb:
  # Enable W&B logging
  enabled: false

  # W&B project name
  project: "ltx-2-trainer"

  # W&B username or team (null uses default account)
  entity: null

  # Tags to help organize runs
  tags: ["ltx2", "lora"]

  # Log validation videos to W&B
  log_validation_videos: true

# -----------------------------------------------------------------------------
# General Configuration
# -----------------------------------------------------------------------------
# Global settings for the training run.

# Random seed for reproducibility
seed: 42

# Directory to save outputs (checkpoints, validation videos, logs)
output_dir: "outputs/ltx2_av_lora"
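A quick worked example (not part of the commit) of three relationships the comments in this config describe: the alpha/rank scaling, the effective batch size, and the frames % 8 == 1 constraint on video_dims.

rank, alpha = 32, 32
print("LoRA scaling:", alpha / rank)  # 1.0 when alpha == rank

# num_gpus of 4 mirrors num_processes in the accelerate configs above.
batch_size, gradient_accumulation_steps, num_gpus = 1, 1, 4
print("effective batch size:", batch_size * gradient_accumulation_steps * num_gpus)  # 4

width, height, frames = 576, 576, 89
assert width % 32 == 0 and height % 32 == 0
assert frames % 8 == 1  # valid counts: 1, 9, 17, ..., 89, ...
print("clip length:", frames / 25.0, "seconds at 25 fps")  # 3.56 s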
packages/ltx-trainer/configs/ltx2_av_lora_low_vram.yaml
ADDED
@@ -0,0 +1,325 @@
# =============================================================================
# LTX-2 Audio-Video LoRA Training Configuration (Low VRAM)
# =============================================================================
#
# This is a memory-optimized variant of the standard audio-video LoRA config.
# It uses 8-bit optimizer, int8 quantization, and reduced LoRA rank to minimize
# GPU memory usage while maintaining good training quality.
#
# Memory optimizations applied:
# - 8-bit AdamW optimizer (reduces optimizer state memory by ~75%)
# - INT8 model quantization (reduces model memory by ~50%)
# - Lower LoRA rank (16 vs 32, reduces trainable parameters)
# - Gradient checkpointing enabled
#
# Recommended for GPUs with 32GB VRAM (e.g., RTX 5090).
#
# Use this configuration when you want to:
# - Fine-tune LTX-2 on your own video dataset
# - Train with or without audio generation
# - Create custom video generation styles or audiovisual concepts
#
# Dataset structure for text-to-video training:
#   preprocessed_data_root/
#   ├── latents/        # Video latents (VAE-encoded videos)
#   ├── conditions/     # Text embeddings for each video
#   └── audio_latents/  # Audio latents (only if with_audio: true)
#
# =============================================================================

# -----------------------------------------------------------------------------
# Model Configuration
# -----------------------------------------------------------------------------
# Specifies the base model to fine-tune and the training mode.
model:
  # Path to the LTX-2 model checkpoint (.safetensors file)
  # This should be a local path to your downloaded model
  model_path: "path/to/ltx-2-model.safetensors"

  # Path to the text encoder model directory
  # For LTX-2, this is typically the Gemma-based text encoder
  text_encoder_path: "path/to/gemma-text-encoder"

  # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning
  # LoRA is recommended for most use cases (faster, less memory, prevents overfitting)
  training_mode: "lora"

  # Optional: Path to resume training from a checkpoint
  # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint)
  load_checkpoint: null

# -----------------------------------------------------------------------------
# LoRA Configuration
# -----------------------------------------------------------------------------
# Controls the Low-Rank Adaptation parameters for efficient fine-tuning.
# Using a lower rank (16) to reduce trainable parameters and memory usage.
# This still provides good capacity for many fine-tuning tasks.
lora:
  # Rank of the LoRA matrices (higher = more capacity but more parameters)
  # Typical values: 8, 16, 32, 64. Using 16 for low VRAM configuration.
  rank: 16

  # Alpha scaling factor (usually set equal to rank)
  # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0
  alpha: 16

  # Dropout probability for LoRA layers (0.0 = no dropout)
  # Can help with regularization if overfitting occurs
  dropout: 0.0

  # Which transformer modules to apply LoRA to
  # The LTX-2 transformer has separate attention and FFN blocks for video and audio:
  #
  # VIDEO MODULES:
  # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention)
  # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text)
  # - ff.net.0.proj, ff.net.2 (video feed-forward)
  #
  # AUDIO MODULES:
  # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention)
  # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text)
  # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward)
  #
  # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction):
  # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0
  #   (Q from video, K/V from audio - allows video to attend to audio features)
  # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0
  #   (Q from audio, K/V from video - allows audio to attend to video features)
  #
  # Using short patterns like "to_k" matches ALL attention modules (video, audio, and cross-modal).
  # For audio-video training, this is the recommended approach.
  target_modules:
    # Attention layers (matches both video and audio branches)
    - "to_k"
    - "to_q"
    - "to_v"
    - "to_out.0"
    # Uncomment below to also train feed-forward layers (can increase the LoRA's capacity):
    # - "ff.net.0.proj"
    # - "ff.net.2"
    # - "audio_ff.net.0.proj"
    # - "audio_ff.net.2"

# -----------------------------------------------------------------------------
# Training Strategy Configuration
# -----------------------------------------------------------------------------
# Defines the text-to-video training approach.
training_strategy:
  # Strategy name: "text_to_video" for standard text-to-video training
  name: "text_to_video"

  # Probability of conditioning on the first frame during training
  # Higher values train the model to perform better in image-to-video (I2V) mode,
  # where a clean first frame is provided and the model generates the rest of the video
  first_frame_conditioning_p: 0.5

  # Enable joint audio-video training
  # Set to true if your dataset includes audio and you want to train the audio branch
  with_audio: true

  # Directory name (within preprocessed_data_root) containing audio latents
  # Only used when with_audio is true
  audio_latents_dir: "audio_latents"

# -----------------------------------------------------------------------------
# Optimization Configuration
# -----------------------------------------------------------------------------
# Controls the training optimization parameters.
optimization:
  # Learning rate for the optimizer
  # Typical range for LoRA: 1e-5 to 1e-4
  learning_rate: 1e-4

  # Total number of training steps
  steps: 2000

  # Batch size per GPU
  # Reduce if running out of memory
  batch_size: 1

  # Number of gradient accumulation steps
  # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus
  gradient_accumulation_steps: 1

  # Maximum gradient norm for clipping (helps training stability)
  max_grad_norm: 1.0

  # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient)
  # Using 8-bit AdamW to reduce optimizer state memory by ~75%
  optimizer_type: "adamw8bit"

  # Learning rate scheduler type
  # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial"
  scheduler_type: "linear"

  # Additional scheduler parameters (depends on scheduler_type)
  scheduler_params: {}

  # Enable gradient checkpointing to reduce memory usage
  # Recommended for training with limited GPU memory
  enable_gradient_checkpointing: true

# -----------------------------------------------------------------------------
# Acceleration Configuration
# -----------------------------------------------------------------------------
# Hardware acceleration and memory optimization settings.
acceleration:
  # Mixed precision training mode
  # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended)
  mixed_precision_mode: "bf16"

  # Model quantization for reduced memory usage
  # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto"
  # Using INT8 quantization to reduce base model memory consumption by ~50%
  quantization: "int8-quanto"

  # Load text encoder in 8-bit precision to save memory
  # Useful when GPU memory is limited
  load_text_encoder_in_8bit: true

# -----------------------------------------------------------------------------
# Data Configuration
# -----------------------------------------------------------------------------
# Specifies the training data location and loading parameters.
data:
  # Root directory containing preprocessed training data
  # Should contain: latents/, conditions/, and optionally audio_latents/
  preprocessed_data_root: "/path/to/preprocessed/data"

  # Number of worker processes for data loading
  # Used to parallelize and speed up data loading
  num_dataloader_workers: 2

# -----------------------------------------------------------------------------
# Validation Configuration
# -----------------------------------------------------------------------------
# Controls validation video generation during training.
# NOTE: Validation sampling uses simplified inference pipelines and prioritizes speed over
# maximum quality. For production-quality inference, use `packages/ltx-pipelines`.
validation:
  # Text prompts for validation video generation
  # Provide prompts representative of your training data
  # LTX-2 prefers longer, detailed prompts that describe both visual content and audio
  prompts:
    - "A woman with long brown hair sits at a wooden desk in a cozy home office, typing on a laptop while occasionally glancing at notes beside her. Soft natural light streams through a large window, casting warm shadows across the room. She pauses to take a sip from a ceramic mug, then continues working with focused concentration. The audio captures the gentle clicking of keyboard keys, the soft rustle of papers, and ambient room tone with occasional distant bird chirps from outside."
    - "A chef in a white uniform stands in a professional kitchen, carefully plating a gourmet dish with precise movements. Steam rises from freshly cooked vegetables as he arranges them with tweezers. The stainless steel surfaces gleam under bright overhead lights, and various pots simmer on the stove behind him. The audio features the sizzling of pans, the clinking of utensils against plates, and the ambient hum of kitchen ventilation."

  # Negative prompt to avoid unwanted artifacts
  negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"

  # Optional: First frame images for image-to-video validation
  # If provided, must have one image per prompt
  images: null

  # Output video dimensions [width, height, frames]
  # Width and height must be divisible by 32
  # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...)
  video_dims: [576, 576, 49]

  # Frame rate for generated videos
  frame_rate: 25.0

  # Random seed for reproducible validation outputs
  seed: 42

  # Number of denoising steps for validation inference
  # Higher values = better quality but slower generation
  inference_steps: 30

  # Generate validation videos every N training steps
  # Set to null to disable validation during training
  interval: 100

  # Number of videos to generate per prompt
  videos_per_prompt: 1

  # Classifier-free guidance scale
  # Higher values = stronger adherence to prompt but may introduce artifacts
  guidance_scale: 4.0

  # STG (Spatio-Temporal Guidance) parameters for improved video quality
  # STG is combined with CFG for better temporal coherence
  stg_scale: 1.0      # Recommended: 1.0 (0.0 disables STG)
  stg_blocks: [29]    # Recommended: single block 29
  stg_mode: "stg_av"  # "stg_av" perturbs both audio and video, "stg_v" video only

  # Whether to generate audio in validation samples
  # Independent of training_strategy.with_audio - you can generate audio
  # in validation even when not training the audio branch
  generate_audio: true

  # Skip validation at the beginning of training (step 0)
  skip_initial_validation: false

# -----------------------------------------------------------------------------
# Checkpoint Configuration
# -----------------------------------------------------------------------------
# Controls model checkpoint saving during training.
checkpoints:
  # Save a checkpoint every N steps
  # Set to null to disable intermediate checkpoints
  interval: 250

  # Number of most recent checkpoints to keep
  # Set to -1 to keep all checkpoints
  keep_last_n: -1

  # Precision to use when saving checkpoint weights
  # Options: "bfloat16" (default, smaller files) or "float32" (full precision)
  precision: "bfloat16"

# -----------------------------------------------------------------------------
# Flow Matching Configuration
# -----------------------------------------------------------------------------
# Parameters for the flow matching training objective.
flow_matching:
  # Timestep sampling mode
  # "shifted_logit_normal" is recommended for LTX-2 models
  timestep_sampling_mode: "shifted_logit_normal"

  # Additional parameters for timestep sampling
  timestep_sampling_params: {}

# -----------------------------------------------------------------------------
# Hugging Face Hub Configuration
# -----------------------------------------------------------------------------
# Settings for uploading trained models to the Hugging Face Hub.
hub:
  # Whether to push the trained model to the Hub
  push_to_hub: false

  # Repository ID on Hugging Face Hub (e.g., "username/my-lora-model")
  # Required if push_to_hub is true
  hub_model_id: null

# -----------------------------------------------------------------------------
# Weights & Biases Configuration
# -----------------------------------------------------------------------------
# Settings for experiment tracking with W&B.
wandb:
  # Enable W&B logging
  enabled: false

  # W&B project name
  project: "ltx-2-trainer"

  # W&B username or team (null uses default account)
  entity: null

  # Tags to help organize runs
  tags: ["ltx2", "lora"]

  # Log validation videos to W&B
  log_validation_videos: true

# -----------------------------------------------------------------------------
# General Configuration
# -----------------------------------------------------------------------------
# Global settings for the training run.

# Random seed for reproducibility
seed: 42

# Directory to save outputs (checkpoints, validation videos, logs)
output_dir: "outputs/ltx2_av_lora"
packages/ltx-trainer/configs/ltx2_v2v_ic_lora.yaml
ADDED
@@ -0,0 +1,329 @@
# =============================================================================
# LTX-2 Video-to-Video (IC-LoRA) Training Configuration
# =============================================================================
#
# This configuration is for training In-Context LoRA (IC-LoRA) adapters that
# enable video-to-video transformations. IC-LoRA learns to apply visual
# transformations (e.g., depth-to-video, pose control, style transfer, etc.)
# by conditioning on reference videos.
#
# Key differences from text-to-video LoRA:
# - Uses reference videos as conditioning input alongside text prompts
# - Requires preprocessed reference latents in addition to target latents
# - Validation requires reference videos to demonstrate the transformation
#
# Dataset structure for IC-LoRA training:
#   preprocessed_data_root/
#   ├── latents/            # Target video latents (what the model learns to generate)
#   ├── conditions/         # Text embeddings for each video
#   └── reference_latents/  # Reference video latents (conditioning input)
#
# =============================================================================

# -----------------------------------------------------------------------------
# Model Configuration
# -----------------------------------------------------------------------------
# Specifies the base model to fine-tune and the training mode.
model:
  # Path to the LTX-2 model checkpoint (.safetensors file)
  # This should be a local path to your downloaded model
  model_path: "path/to/ltx-2-model.safetensors"

  # Path to the text encoder model directory
  # For LTX-2, this is typically the Gemma-based text encoder
  text_encoder_path: "path/to/gemma-text-encoder"

  # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning
  # Note: video_to_video strategy requires "lora" mode
  training_mode: "lora"

  # Optional: Path to resume training from a checkpoint
  # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint)
  load_checkpoint: null

# -----------------------------------------------------------------------------
# LoRA Configuration
# -----------------------------------------------------------------------------
# Controls the Low-Rank Adaptation parameters for efficient fine-tuning.
lora:
  # Rank of the LoRA matrices (higher = more capacity but more parameters)
  # Typical values: 8, 16, 32, 64. Start with 16-32 for IC-LoRA.
  rank: 32

  # Alpha scaling factor (usually set equal to rank)
  # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0
  alpha: 32

  # Dropout probability for LoRA layers (0.0 = no dropout)
  # Can help with regularization if overfitting occurs
  dropout: 0.0

  # Which transformer modules to apply LoRA to
  # The LTX-2 transformer has separate attention and FFN blocks for video and audio:
  #
  # VIDEO MODULES:
  # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention)
  # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text)
  # - ff.net.0.proj, ff.net.2 (video feed-forward)
  #
  # AUDIO MODULES (not used for video-only IC-LoRA):
  # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention)
  # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text)
  # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward)
  #
  # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction, not used for video-only IC-LoRA):
  # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0
  #   (Q from video, K/V from audio - allows video to attend to audio features)
  # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0
  #   (Q from audio, K/V from video - allows audio to attend to video features)
  #
  # For IC-LoRA (video-only), we explicitly target video modules.
  # Including FFN layers often improves transformation quality.
  target_modules:
    # Video self-attention
    - "attn1.to_k"
    - "attn1.to_q"
    - "attn1.to_v"
    - "attn1.to_out.0"
    # Video cross-attention
    - "attn2.to_k"
    - "attn2.to_q"
    - "attn2.to_v"
    - "attn2.to_out.0"
    # Video feed-forward (often improves transformation quality)
    - "ff.net.0.proj"
    - "ff.net.2"

# -----------------------------------------------------------------------------
# Training Strategy Configuration
# -----------------------------------------------------------------------------
# Defines the video-to-video (IC-LoRA) training approach.
training_strategy:
  # Strategy name: "video_to_video" for IC-LoRA training
  name: "video_to_video"

  # Probability of conditioning on the first frame during training
  # Higher values train the model to perform better in image-to-video (I2V) mode,
  # where a clean first frame is provided and the model generates the rest of the video
  first_frame_conditioning_p: 0.2

  # Directory name (within preprocessed_data_root) containing reference video latents
  # These are the conditioning inputs that guide the transformation
  reference_latents_dir: "reference_latents"

# -----------------------------------------------------------------------------
# Optimization Configuration
# -----------------------------------------------------------------------------
# Controls the training optimization parameters.
optimization:
  # Learning rate for the optimizer
  # Typical range for LoRA: 1e-5 to 1e-4
  learning_rate: 2e-4

  # Total number of training steps
  steps: 3000

  # Batch size per GPU
  # Reduce if running out of memory
  batch_size: 1

  # Number of gradient accumulation steps
  # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus
  gradient_accumulation_steps: 1

  # Maximum gradient norm for clipping (helps training stability)
  max_grad_norm: 1.0

  # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient)
  optimizer_type: "adamw"

  # Learning rate scheduler type
  # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial"
  scheduler_type: "linear"

  # Additional scheduler parameters (depends on scheduler_type)
  scheduler_params: { }

  # Enable gradient checkpointing to reduce memory usage
  # Recommended for training with limited GPU memory
  enable_gradient_checkpointing: true

# -----------------------------------------------------------------------------
# Acceleration Configuration
# -----------------------------------------------------------------------------
# Hardware acceleration and memory optimization settings.
acceleration:
  # Mixed precision training mode
  # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended)
  mixed_precision_mode: "bf16"

  # Model quantization for reduced memory usage
  # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto"
  quantization: null

  # Load text encoder in 8-bit precision to save memory
  # Useful when GPU memory is limited
  load_text_encoder_in_8bit: false

# -----------------------------------------------------------------------------
# Data Configuration
# -----------------------------------------------------------------------------
# Specifies the training data location and loading parameters.
data:
  # Root directory containing preprocessed training data
  # Should contain: latents/, conditions/, and reference_latents/ subdirectories
  preprocessed_data_root: "/path/to/preprocessed/data"

  # Number of worker processes for data loading
  # Workers load samples in parallel to speed up training
  num_dataloader_workers: 2

# -----------------------------------------------------------------------------
# Validation Configuration
# -----------------------------------------------------------------------------
# Controls validation video generation during training.
# NOTE: Validation sampling uses a simplified inference pipeline and prioritizes speed over
# maximum quality. For production-quality inference, use `packages/ltx-pipelines`.
validation:
  # Text prompts for validation video generation
  # Provide prompts representative of your training data
  # LTX-2 prefers longer, detailed prompts that describe both visual content and audio
  prompts:
    - "A man in a casual blue jacket walks along a winding path through a lush green park on a bright sunny afternoon. Tall oak trees line the pathway, their leaves rustling gently in the breeze. Dappled sunlight creates shifting patterns on the ground as he strolls at a relaxed pace, occasionally looking up at the scenery around him. The audio captures footsteps on gravel, birds singing in the trees, distant children playing, and the soft whisper of wind through the foliage."
    - "A fluffy orange tabby cat sits perfectly still on a wooden windowsill, its green eyes intently tracking small birds hopping on a branch just outside the glass. The cat's ears twitch and rotate, following every movement. Warm afternoon light illuminates its fur, creating a soft golden glow. Behind the cat, a cozy living room with a bookshelf and houseplants is visible. The audio features gentle purring, occasional soft meows, muffled bird chirps through the window, and quiet ambient room sounds."

  # Reference videos for validation (REQUIRED for video_to_video strategy)
  # Must provide one reference video per prompt
  # These are the conditioning inputs for generating validation outputs
  reference_videos:
    - "/path/to/reference_video_1.mp4"
    - "/path/to/reference_video_2.mp4"

  # Downscale factor for reference videos (for efficient IC-LoRA training)
  # When > 1, reference videos are processed at 1/n resolution
  # Must match the --reference-downscale-factor used during dataset preprocessing
  # Examples: 1 = same resolution, 2 = half resolution (384x384 ref for 768x768 target)
  reference_downscale_factor: 1

  # Negative prompt to avoid unwanted artifacts
  negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"

  # Optional: First frame images for additional conditioning
  # If provided, must have one image per prompt
  images: null

  # Output video dimensions [width, height, frames]
  # Width and height must be divisible by 32
  # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...)
  video_dims: [ 512, 512, 81 ]

  # Frame rate for generated videos
  frame_rate: 25.0

  # Random seed for reproducible validation outputs
  seed: 42

  # Number of denoising steps for validation inference
  # Higher values = better quality but slower generation
  inference_steps: 30

  # Generate validation videos every N training steps
  # Set to null to disable validation during training
  interval: 100

  # Number of videos to generate per prompt
  videos_per_prompt: 1

  # Classifier-free guidance scale
  # Higher values = stronger adherence to prompt but may introduce artifacts
  guidance_scale: 4.0

  # STG (Spatio-Temporal Guidance) parameters for improved video quality
  # STG is combined with CFG for better temporal coherence
  stg_scale: 1.0     # Recommended: 1.0 (0.0 disables STG)
  stg_blocks: [29]   # Recommended: single block 29
  stg_mode: "stg_v"  # "stg_v" for video-only (no audio training)

  # Whether to generate audio in validation samples
  # Can be enabled even when not training the audio branch
  generate_audio: false

  # Skip validation at the beginning of training (step 0)
  skip_initial_validation: false

  # Concatenate reference video side-by-side with generated output
  # Useful for visually comparing the transformation quality
  include_reference_in_output: true

# -----------------------------------------------------------------------------
# Checkpoint Configuration
# -----------------------------------------------------------------------------
# Controls model checkpoint saving during training.
checkpoints:
  # Save a checkpoint every N steps
  # Set to null to disable intermediate checkpoints
  interval: 250

  # Number of most recent checkpoints to keep
  # Set to -1 to keep all checkpoints
  keep_last_n: 3

  # Precision to use when saving checkpoint weights
  # Options: "bfloat16" (default, smaller files) or "float32" (full precision)
  precision: "bfloat16"

# -----------------------------------------------------------------------------
# Flow Matching Configuration
# -----------------------------------------------------------------------------
# Parameters for the flow matching training objective.
flow_matching:
  # Timestep sampling mode
  # "shifted_logit_normal" is recommended for LTX-2 models
  timestep_sampling_mode: "shifted_logit_normal"

  # Additional parameters for timestep sampling
  timestep_sampling_params: { }

# -----------------------------------------------------------------------------
# Hugging Face Hub Configuration
# -----------------------------------------------------------------------------
# Settings for uploading trained models to the Hugging Face Hub.
hub:
  # Whether to push the trained model to the Hub
  push_to_hub: false

  # Repository ID on Hugging Face Hub (e.g., "username/my-ic-lora-model")
  # Required if push_to_hub is true
  hub_model_id: null

# -----------------------------------------------------------------------------
# Weights & Biases Configuration
# -----------------------------------------------------------------------------
# Settings for experiment tracking with W&B.
wandb:
  # Enable W&B logging
  enabled: false

  # W&B project name
  project: "ltx-2-trainer"

  # W&B username or team (null uses default account)
  entity: null

  # Tags to help organize runs
  tags: [ "ltx2", "ic-lora", "video-to-video" ]

  # Log validation videos to W&B
  log_validation_videos: true

# -----------------------------------------------------------------------------
# General Configuration
# -----------------------------------------------------------------------------
# Global settings for the training run.

# Random seed for reproducibility
seed: 42

# Directory to save outputs (checkpoints, validation videos, logs)
output_dir: "outputs/ltx2_v2v_ic_lora"
packages/ltx-trainer/docs/configuration-reference.md
ADDED
@@ -0,0 +1,372 @@
# Configuration Reference

The trainer uses structured Pydantic models for configuration, making it easy to customize training parameters.
This guide covers all available configuration options and their usage.

## 📋 Overview

The main configuration class is [`LtxTrainerConfig`](../src/ltx_trainer/config.py), which includes the following
sub-configurations:

- **ModelConfig**: Base model and training mode settings
- **LoraConfig**: LoRA training parameters
- **TrainingStrategyConfig**: Training strategy settings (text-to-video or video-to-video)
- **OptimizationConfig**: Learning rate, batch sizes, and scheduler settings
- **AccelerationConfig**: Mixed precision and quantization settings
- **DataConfig**: Data loading parameters
- **ValidationConfig**: Validation and inference settings
- **CheckpointsConfig**: Checkpoint saving frequency and retention settings
- **HubConfig**: Hugging Face Hub integration settings
- **WandbConfig**: Weights & Biases logging settings
- **FlowMatchingConfig**: Timestep sampling parameters

## 📄 Example Configuration Files

Check out our example configurations in the `configs` directory:

- 📄 [Audio-Video LoRA Training](../configs/ltx2_av_lora.yaml) - Joint audio-video generation training
- 📄 [Audio-Video LoRA Training (Low VRAM)](../configs/ltx2_av_lora_low_vram.yaml) - Memory-optimized config for 32GB
  GPUs (uses 8-bit optimizer, INT8 quantization, and reduced LoRA rank)
- 📄 [IC-LoRA Training](../configs/ltx2_v2v_ic_lora.yaml) - Video-to-video transformation training

## ⚙️ Configuration Sections

### ModelConfig

Controls the base model and training mode settings.

```yaml
model:
  model_path: "/path/to/ltx-2-model.safetensors"  # Local path to model checkpoint
  text_encoder_path: "/path/to/gemma-model"       # Path to Gemma text encoder directory
  training_mode: "lora"                           # "lora" or "full"
  load_checkpoint: null                           # Path to checkpoint to resume from
```

**Key parameters:**

| Parameter           | Description |
|---------------------|-------------|
| `model_path`        | **Required.** Local path to the LTX-2 model checkpoint (`.safetensors` file). URLs are not supported. |
| `text_encoder_path` | **Required.** Path to the Gemma text encoder model directory. Download from [HuggingFace](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/). |
| `training_mode`     | Training approach - `"lora"` for LoRA training or `"full"` for full-rank fine-tuning. |
| `load_checkpoint`   | Optional path to resume training from a checkpoint file or directory. |

> [!NOTE]
> LTX-2 requires both a model checkpoint and a Gemma text encoder. Both must be local paths.

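Both can be fetched ahead of time with the `huggingface_hub` Python API. A minimal sketch (the LTX-2 repository ID below is a placeholder; only the Gemma repository linked above is confirmed by this guide):

```python
from huggingface_hub import snapshot_download

# Download the Gemma text encoder referenced above to a local directory.
text_encoder_dir = snapshot_download(
    repo_id="google/gemma-3-12b-it-qat-q4_0-unquantized",
    local_dir="models/gemma-text-encoder",
)

# Placeholder: substitute the LTX-2 checkpoint repository you actually use.
# model_dir = snapshot_download(repo_id="<ltx-2-repo>", local_dir="models/ltx-2")
```
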
### LoraConfig

LoRA-specific fine-tuning parameters (only used when `training_mode: "lora"`).

```yaml
lora:
  rank: 32           # LoRA rank (higher = more parameters)
  alpha: 32          # LoRA alpha scaling factor
  dropout: 0.0       # Dropout probability (0.0-1.0)
  target_modules:    # Modules to apply LoRA to
    - "to_k"
    - "to_q"
    - "to_v"
    - "to_out.0"
```

**Key parameters:**

| Parameter        | Description |
|------------------|-------------|
| `rank`           | LoRA rank - higher values mean more trainable parameters (typical range: 8-128) |
| `alpha`          | Alpha scaling factor - typically set equal to rank, since the effective LoRA scaling is `alpha / rank` |
| `dropout`        | Dropout probability for regularization |
| `target_modules` | List of transformer modules to apply LoRA adapters to (see below) |

#### Understanding Target Modules

The LTX-2 transformer has separate attention and feed-forward blocks for video and audio, as well as cross-attention
modules that enable the two modalities to exchange information. Choosing the right `target_modules` is critical for
achieving good results, especially when training with audio.

**Video-only modules:**

| Module Pattern | Description |
|----------------|-------------|
| `attn1.to_k`, `attn1.to_q`, `attn1.to_v`, `attn1.to_out.0` | Video self-attention |
| `attn2.to_k`, `attn2.to_q`, `attn2.to_v`, `attn2.to_out.0` | Video cross-attention (to text) |
| `ff.net.0.proj`, `ff.net.2` | Video feed-forward network |

**Audio-only modules:**

| Module Pattern | Description |
|----------------|-------------|
| `audio_attn1.to_k`, `audio_attn1.to_q`, `audio_attn1.to_v`, `audio_attn1.to_out.0` | Audio self-attention |
| `audio_attn2.to_k`, `audio_attn2.to_q`, `audio_attn2.to_v`, `audio_attn2.to_out.0` | Audio cross-attention (to text) |
| `audio_ff.net.0.proj`, `audio_ff.net.2` | Audio feed-forward network |

**Audio-video cross-attention modules:**

These modules enable bidirectional information flow between the audio and video modalities:

| Module Pattern | Description |
|----------------|-------------|
| `audio_to_video_attn.to_k`, `audio_to_video_attn.to_q`, `audio_to_video_attn.to_v`, `audio_to_video_attn.to_out.0` | Video attends to audio (Q from video, K/V from audio) |
| `video_to_audio_attn.to_k`, `video_to_audio_attn.to_q`, `video_to_audio_attn.to_v`, `video_to_audio_attn.to_out.0` | Audio attends to video (Q from audio, K/V from video) |

**Recommended configurations:**

For **video-only training**, target the video attention layers:

```yaml
target_modules:
  - "attn1.to_k"
  - "attn1.to_q"
  - "attn1.to_v"
  - "attn1.to_out.0"
  - "attn2.to_k"
  - "attn2.to_q"
  - "attn2.to_v"
  - "attn2.to_out.0"
```

For **audio-video training**, use patterns that match both branches:

```yaml
target_modules:
  - "to_k"
  - "to_q"
  - "to_v"
  - "to_out.0"
```

> [!NOTE]
> Using shorter patterns like `"to_k"` will match all attention modules including `attn1.to_k`, `audio_attn1.to_k`,
> `audio_to_video_attn.to_k`, and `video_to_audio_attn.to_k`, effectively training video, audio, and cross-modal
> attention branches together.

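To preview which modules a pattern would select, you can filter the transformer's module names yourself. The sketch below assumes PEFT-style matching, where a pattern matches a module whose name equals it or ends with `"." + pattern`; check the trainer's LoRA code for the exact rule it applies:

```python
def matches(name: str, patterns: list[str]) -> bool:
    """PEFT-style match: pattern equals the name or is a '.'-delimited suffix."""
    return any(name == p or name.endswith("." + p) for p in patterns)

# Hypothetical module names from one transformer block:
names = [
    "blocks.0.attn1.to_k",                # video self-attention
    "blocks.0.audio_attn1.to_k",          # audio self-attention
    "blocks.0.audio_to_video_attn.to_k",  # cross-modal attention
]

print([n for n in names if matches(n, ["to_k"])])
# -> all three branches (video, audio, and cross-modal)

print([n for n in names if matches(n, ["attn1.to_k"])])
# -> only 'blocks.0.attn1.to_k'; the '.'-delimited suffix rule keeps
#    'audio_attn1.to_k' from matching
```
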
> [!TIP]
> You can also target the feed-forward (FFN) modules (`ff.net.0.proj`, `ff.net.2` for video,
> `audio_ff.net.0.proj`, `audio_ff.net.2` for audio) to increase the LoRA's capacity and potentially
> help it capture the target distribution better.

### TrainingStrategyConfig

Configures the training strategy. The trainer includes two built-in strategies described below.
For custom use cases, see [Implementing Custom Training Strategies](custom-training-strategies.md).

#### Text-to-Video Strategy

```yaml
training_strategy:
  name: "text_to_video"
  first_frame_conditioning_p: 0.1     # Probability of first-frame conditioning
  with_audio: false                   # Enable joint audio-video training
  audio_latents_dir: "audio_latents"  # Directory for audio latents (when with_audio: true)
```

#### Video-to-Video Strategy (IC-LoRA)

```yaml
training_strategy:
  name: "video_to_video"
  first_frame_conditioning_p: 0.1
  reference_latents_dir: "reference_latents"  # Directory for reference video latents
```

**Key parameters:**

| Parameter                    | Description |
|------------------------------|-------------|
| `name`                       | Strategy type: `"text_to_video"` or `"video_to_video"` |
| `first_frame_conditioning_p` | Probability of using first frame as conditioning (0.0-1.0) |
| `with_audio`                 | (text_to_video only) Enable joint audio-video training |
| `audio_latents_dir`          | (text_to_video only) Directory name for audio latents |
| `reference_latents_dir`      | (video_to_video only) Directory name for reference video latents |

### OptimizationConfig

Training optimization parameters including learning rates, batch sizes, and schedulers.

```yaml
optimization:
  learning_rate: 1e-4                  # Learning rate
  steps: 2000                          # Total training steps
  batch_size: 1                        # Batch size per GPU
  gradient_accumulation_steps: 1       # Steps to accumulate gradients
  max_grad_norm: 1.0                   # Gradient clipping threshold
  optimizer_type: "adamw"              # "adamw" or "adamw8bit"
  scheduler_type: "linear"             # Scheduler type
  scheduler_params: { }                # Additional scheduler parameters
  enable_gradient_checkpointing: true  # Memory optimization
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `learning_rate` | Learning rate for optimization (typical range: 1e-5 to 1e-3) |
| `steps` | Total number of training steps |
| `batch_size` | Batch size per GPU (reduce if running out of memory) |
| `gradient_accumulation_steps` | Accumulate gradients over multiple steps; effective batch size = `batch_size * gradient_accumulation_steps * num_gpus` |
| `scheduler_type` | LR scheduler: `"constant"`, `"linear"`, `"cosine"`, `"cosine_with_restarts"`, `"polynomial"` |
| `enable_gradient_checkpointing` | Trade training speed for GPU memory savings (recommended for large models) |

### AccelerationConfig

Hardware acceleration and compute optimization settings.

```yaml
acceleration:
  mixed_precision_mode: "bf16"      # "no", "fp16", or "bf16"
  quantization: null                # Quantization options
  load_text_encoder_in_8bit: false  # Load text encoder in 8-bit
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `mixed_precision_mode` | Precision mode - `"bf16"` recommended for modern GPUs |
| `quantization` | Model quantization: `null`, `"int8-quanto"`, `"int4-quanto"`, `"fp8-quanto"`, etc. |
| `load_text_encoder_in_8bit` | Load the Gemma text encoder in 8-bit to save GPU memory |

### DataConfig

Data loading and processing configuration.

```yaml
data:
  preprocessed_data_root: "/path/to/preprocessed/data"  # Path to precomputed dataset
  num_dataloader_workers: 2                             # Background data loading workers
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `preprocessed_data_root` | Path to your preprocessed dataset (contains `latents/`, `conditions/`, etc.) |
| `num_dataloader_workers` | Number of parallel data loading processes (0 = synchronous loading, useful when debugging) |

### ValidationConfig

Validation and inference settings for monitoring training progress.

```yaml
validation:
  prompts:                            # Validation prompts
    - "A cat playing with a ball"
    - "A dog running in a field"
  negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
  images: null                        # Optional image paths for image-to-video
  reference_videos: null              # Reference video paths (IC-LoRA only)
  video_dims: [ 576, 576, 89 ]        # Video dimensions [width, height, frames]
  frame_rate: 25.0                    # Frame rate for generated videos
  seed: 42                            # Random seed for reproducibility
  inference_steps: 30                 # Number of inference steps
  interval: 100                       # Steps between validation runs
  videos_per_prompt: 1                # Videos generated per prompt
  guidance_scale: 4.0                 # CFG guidance strength
  stg_scale: 1.0                      # STG guidance strength (0.0 to disable)
  stg_blocks: [ 29 ]                  # Transformer blocks to perturb for STG
  stg_mode: "stg_av"                  # "stg_av" or "stg_v" (video only)
  generate_audio: true                # Whether to generate audio
  skip_initial_validation: false      # Skip validation at step 0
  include_reference_in_output: false  # Include reference video side-by-side (IC-LoRA)
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `prompts` | List of text prompts for validation video generation |
| `images` | List of image paths for image-to-video validation (must match number of prompts) |
| `reference_videos` | List of reference video paths for IC-LoRA validation (must match number of prompts) |
| `video_dims` | Output dimensions `[width, height, frames]`. Width/height must be divisible by 32, frames must satisfy `frames % 8 == 1` |
| `interval` | Steps between validation runs (set to `null` to disable) |
| `guidance_scale` | CFG (Classifier-Free Guidance) scale. Recommended: 4.0 |
| `stg_scale` | STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. Recommended: 1.0 |
| `stg_blocks` | Transformer blocks to perturb for STG. Recommended: `[29]` (single block) |
| `stg_mode` | STG mode: `"stg_av"` perturbs both audio and video, `"stg_v"` perturbs video only |
| `generate_audio` | Whether to generate audio in validation samples |
| `include_reference_in_output` | For IC-LoRA: concatenate reference video side-by-side with output |

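The `video_dims` constraints can be checked before launching a run. A small standalone sketch (a hypothetical helper, not part of the trainer):

```python
def validate_video_dims(width: int, height: int, frames: int) -> None:
    """Check the LTX-2 constraints: width/height divisible by 32, frames % 8 == 1."""
    if width % 32 != 0 or height % 32 != 0:
        raise ValueError(f"width and height must be divisible by 32, got {width}x{height}")
    if frames % 8 != 1:
        nearby = (frames // 8) * 8 + 1  # a nearby valid frame count
        raise ValueError(f"frames must satisfy frames % 8 == 1, got {frames} (try {nearby})")

validate_video_dims(576, 576, 89)    # OK: 576 % 32 == 0 and 89 % 8 == 1
# validate_video_dims(512, 512, 48)  # would raise and suggest 49
```
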
### CheckpointsConfig

Model checkpointing configuration.

```yaml
checkpoints:
  interval: 250        # Steps between checkpoint saves (null = disabled)
  keep_last_n: 3       # Number of recent checkpoints to retain
  precision: bfloat16  # Precision for saved weights (bfloat16 or float32)
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `interval` | Steps between intermediate checkpoint saves (set to `null` to disable) |
| `keep_last_n` | Number of most recent checkpoints to keep (-1 = keep all) |
| `precision` | Precision for saved checkpoint weights: `"bfloat16"` (default) or `"float32"` |

### HubConfig

Hugging Face Hub integration for automatic model uploads.

```yaml
hub:
  push_to_hub: false                   # Enable Hub uploading
  hub_model_id: "username/model-name"  # Hub repository ID
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `push_to_hub` | Whether to automatically push trained models to Hugging Face Hub |
| `hub_model_id` | Repository ID in format `"username/repository-name"` |

### WandbConfig

Weights & Biases logging configuration.

```yaml
wandb:
  enabled: false                # Enable W&B logging
  project: "ltx-2-trainer"      # W&B project name
  entity: null                  # W&B username or team
  tags: [ ]                     # Tags for the run
  log_validation_videos: true   # Log validation videos to W&B
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `enabled` | Whether to enable W&B logging |
| `project` | W&B project name |
| `entity` | W&B username or team (null uses default account) |
| `log_validation_videos` | Whether to log validation videos to W&B |

### FlowMatchingConfig

Flow matching training configuration for timestep sampling.

```yaml
flow_matching:
  timestep_sampling_mode: "shifted_logit_normal"  # Timestep sampling strategy
  timestep_sampling_params: { }                   # Additional sampling parameters
```

**Key parameters:**

| Parameter | Description |
|-----------|-------------|
| `timestep_sampling_mode` | Sampling strategy: `"uniform"` or `"shifted_logit_normal"` |
| `timestep_sampling_params` | Additional parameters for the sampling strategy |

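For intuition, the common "shifted logit-normal" formulation draws a normal variate, squashes it through a sigmoid to get a logit-normal timestep in (0, 1), and then shifts it toward the high-noise end. A sketch of that formulation (the trainer's exact parameterization may differ; `mean`, `std`, and `shift` here are illustrative):

```python
import torch


def sample_shifted_logit_normal(
    n: int, mean: float = 0.0, std: float = 1.0, shift: float = 3.0
) -> torch.Tensor:
    """Sample n timesteps in (0, 1): logit-normal base, then a shift toward t = 1."""
    t = torch.sigmoid(torch.randn(n) * std + mean)  # logit-normal in (0, 1)
    return shift * t / (1.0 + (shift - 1.0) * t)    # shift > 1 skews samples noisier


print(sample_shifted_logit_normal(4))  # four timesteps, skewed toward 1 for shift > 1
```
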
## 🚀 Next Steps

Once you've configured your training parameters:

- Set up your dataset using [Dataset Preparation](dataset-preparation.md)
- Choose your training approach in [Training Modes](training-modes.md)
- Start training with the [Training Guide](training-guide.md)
packages/ltx-trainer/docs/custom-training-strategies.md
ADDED
@@ -0,0 +1,510 @@
# Implementing Custom Training Strategies

This guide explains how to implement your own training strategy for specialized use cases like audio-only training,
video inpainting, or other custom training recipes.

## 📋 Overview

The trainer uses the **Strategy Pattern** to separate training logic from the core training loop. Each strategy defines:

1. **What data is needed** - Which preprocessed data directories to load
2. **How to prepare inputs** - Transform batch data into model inputs
3. **How to compute loss** - Calculate the training objective

This architecture lets you implement new training modes without modifying the core trainer code.

### When You Need a Custom Strategy

Consider implementing a custom strategy when you need:

- **Different input modalities** (e.g., audio-only, audio-to-video conditioning)
- **Additional conditioning signals** (e.g., masks for inpainting, depth maps)
- **Custom loss computation** (e.g., weighted losses, auxiliary losses)
- **Different noise application patterns** (e.g., partial masking)

## 🏗️ Architecture Overview

### How Strategies Fit Into the Trainer

The trainer delegates all training-mode-specific logic to the strategy:

1. **Initialization** — The trainer calls `get_data_sources()` to determine which preprocessed data directories to load
2. **Each training step:**
   - Calls `prepare_training_inputs()` to transform the raw batch into model-ready inputs
   - Runs the transformer forward pass
   - Calls `compute_loss()` to compute the training objective

The trainer handles everything else: optimization, checkpointing, validation, and distributed training.

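The snippet below makes that delegation concrete with minimal stand-ins; the real `TrainingStrategy` and transformer classes live in `ltx_trainer`/`ltx_core` and carry much more state, and the real loop also handles gradient accumulation, mixed precision, and distributed synchronization:

```python
import torch


class DummyStrategy:
    """Stand-in strategy: adds noise and trains the model to predict it."""

    def prepare_training_inputs(self, batch, timestep_sampler):
        noise = torch.randn_like(batch)
        return {"noisy": batch + noise, "target": noise}

    def compute_loss(self, pred, inputs):
        return torch.nn.functional.mse_loss(pred, inputs["target"])


model = torch.nn.Linear(8, 8)  # stand-in for the transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
strategy = DummyStrategy()

for batch in [torch.randn(2, 8) for _ in range(3)]:  # stand-in dataloader
    inputs = strategy.prepare_training_inputs(batch, timestep_sampler=None)
    pred = model(inputs["noisy"])               # trainer runs the forward pass
    loss = strategy.compute_loss(pred, inputs)  # strategy owns the objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
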
### Key Components

| Component | Purpose |
|-----------|---------|
| [`TrainingStrategyConfigBase`](../src/ltx_trainer/training_strategies/base_strategy.py) | Base class for strategy configuration (Pydantic model) |
| [`TrainingStrategy`](../src/ltx_trainer/training_strategies/base_strategy.py) | Abstract base class defining the strategy interface |
| [`ModelInputs`](../src/ltx_trainer/training_strategies/base_strategy.py) | Dataclass containing prepared inputs for the transformer |
| [`Modality`](../../ltx-core/src/ltx_core/model/transformer/modality.py) | ltx-core dataclass representing video or audio modality data |

## 📝 Step-by-Step Implementation

### Step 1: Plan Your Strategy

Before writing code, answer these questions:

1. **What additional data does your strategy need?**
   - Example: Inpainting needs mask latents alongside video latents
   - Example: Audio-to-video needs reference audio embeddings

2. **What does conditioning look like?**
   - Which tokens should be noised vs. kept clean?
   - How should conditioning tokens be structured (e.g., first frame, reference video, mask)?

3. **How should loss be computed?**
   - Which tokens contribute to the loss?
   - Are there multiple loss terms to combine?

### Step 2: Extend Data Preprocessing (If Needed)

If your strategy requires additional preprocessed data beyond video latents, audio latents, and text embeddings, you'll
need to extend the preprocessing pipeline.

#### Option A: Modify `process_dataset.py`

For integrated preprocessing, add new arguments and processing steps to the main script. For example, to add mask
preprocessing:

```python
# In process_dataset.py, add a new argument
@app.command()
def main(
    # ... existing arguments ...
    mask_column: str | None = typer.Option(
        default=None,
        help="Column name containing mask video paths (for inpainting)",
    ),
) -> None:
    # ... existing processing ...

    # Process masks if provided
    if mask_column:
        logger.info("Processing mask videos for inpainting training...")
        mask_latents_dir = output_base / "mask_latents"

        compute_latents(
            dataset_file=dataset_path,
            video_column=mask_column,
            resolution_buckets=parsed_resolution_buckets,
            output_dir=str(mask_latents_dir),
            model_path=model_path,
            # ... other args ...
        )
```

#### Option B: Create a Standalone Script

For complex preprocessing that doesn't fit naturally into the existing pipeline, create a dedicated script
(e.g., `scripts/process_masks.py`). Use [`scripts/compute_reference.py`](../scripts/compute_reference.py) as a
template - it shows how to process paired data and update the dataset JSON.

#### Expected Output Structure

Your preprocessing should create a directory structure that the strategy can reference:

```
preprocessed_data_root/
├── latents/            # Video latents (standard)
├── conditions/         # Text embeddings (standard)
├── audio_latents/      # Audio latents (if with_audio)
├── mask_latents/       # Your custom data directory
└── reference_latents/  # Reference videos (for IC-LoRA)
```

### Step 3: Create the Strategy Configuration

Create a new file for your strategy (e.g., `src/ltx_trainer/training_strategies/inpainting.py`):

```python
"""Inpainting training strategy.

This strategy implements video inpainting training where:
- Mask latents indicate which regions to inpaint
- Loss is computed only on masked (inpainted) regions
"""

from typing import Any, Literal

import torch
from pydantic import Field
from torch import Tensor

from ltx_core.model.transformer.modality import Modality
from ltx_trainer.timestep_samplers import TimestepSampler
from ltx_trainer.training_strategies.base_strategy import (
    ModelInputs,
    TrainingStrategy,
    TrainingStrategyConfigBase,
)


class InpaintingConfig(TrainingStrategyConfigBase):
    """Configuration for inpainting training strategy."""

    # The 'name' field acts as a discriminator for the config union
    name: Literal["inpainting"] = "inpainting"

    mask_latents_dir: str = Field(
        default="mask_latents",
        description="Directory name for mask latents",
    )

    # Add any strategy-specific parameters
    mask_threshold: float = Field(
        default=0.5,
        description="Threshold for binary mask conversion",
        ge=0.0,
        le=1.0,
    )
```

**Key points:**

- Inherit from `TrainingStrategyConfigBase`
- Use `Literal["your_strategy_name"]` for the `name` field - this enables automatic strategy selection (see the sketch below)
- Use Pydantic `Field` for validation and documentation

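For context, the `Literal` discriminator is what lets Pydantic pick the right config class from the `name` field in your YAML. A self-contained sketch of the mechanism (the union and field names here are illustrative, not the trainer's actual definitions):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class TextToVideoConfig(BaseModel):
    name: Literal["text_to_video"] = "text_to_video"
    first_frame_conditioning_p: float = 0.1


class InpaintingConfig(BaseModel):
    name: Literal["inpainting"] = "inpainting"
    mask_latents_dir: str = "mask_latents"


class TrainerConfig(BaseModel):
    # Pydantic inspects 'name' and instantiates the matching config class.
    training_strategy: Annotated[
        Union[TextToVideoConfig, InpaintingConfig],
        Field(discriminator="name"),
    ]


cfg = TrainerConfig.model_validate(
    {"training_strategy": {"name": "inpainting", "mask_latents_dir": "mask_latents"}}
)
print(type(cfg.training_strategy).__name__)  # InpaintingConfig
```
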
### Step 4: Implement the Strategy Class

```python
class InpaintingStrategy(TrainingStrategy):
    """Inpainting training strategy.

    Trains the model to fill in masked regions of videos while
    keeping unmasked regions as conditioning.
    """

    config: InpaintingConfig

    def __init__(self, config: InpaintingConfig):
        super().__init__(config)

    @property
    def requires_audio(self) -> bool:
        """Whether this strategy requires audio components."""
        return False  # Set to True if your strategy needs audio

    def get_data_sources(self) -> dict[str, str]:
        """Define which data directories to load.

        Returns a mapping of directory names to batch keys.
        The trainer will load .pt files from each directory and
        make them available in the batch under the specified key.
        """
        return {
            "latents": "latents",                   # -> batch["latents"]
            "conditions": "conditions",             # -> batch["conditions"]
            self.config.mask_latents_dir: "masks",  # -> batch["masks"]
        }

    def prepare_training_inputs(
        self,
        batch: dict[str, Any],
        timestep_sampler: TimestepSampler,
    ) -> ModelInputs:
        """Transform batch data into model inputs.

        This is where the core training logic lives:
        1. Extract and patchify latents
        2. Sample noise and apply it appropriately
        3. Create conditioning masks
        4. Build Modality objects for the transformer
        """
        # Get video latents [B, C, F, H, W]
        latents_data = batch["latents"]
        video_latents = latents_data["latents"]

        # Get dimensions
        num_frames = latents_data["num_frames"][0].item()
        height = latents_data["height"][0].item()
        width = latents_data["width"][0].item()

        # Patchify: [B, C, F, H, W] -> [B, seq_len, C]
        video_latents = self._video_patchifier.patchify(video_latents)

        batch_size, seq_len, _ = video_latents.shape
        device = video_latents.device
        dtype = video_latents.dtype

        # Get mask latents and process them
        mask_data = batch["masks"]
        mask_latents = mask_data["latents"]
        mask_latents = self._video_patchifier.patchify(mask_latents)

        # Create binary mask: True = inpaint this region, False = keep original
        inpaint_mask = mask_latents.mean(dim=-1) > self.config.mask_threshold

        # Sample noise and sigmas
        sigmas = timestep_sampler.sample_for(video_latents)
        noise = torch.randn_like(video_latents)

        # Apply noise only to inpaint regions
        sigmas_expanded = sigmas.view(-1, 1, 1)
        noisy_latents = (1 - sigmas_expanded) * video_latents + sigmas_expanded * noise

        # Keep original latents for non-inpaint regions (conditioning)
        inpaint_mask_expanded = inpaint_mask.unsqueeze(-1)
        noisy_latents = torch.where(inpaint_mask_expanded, noisy_latents, video_latents)

        # Create per-token timesteps
        # Conditioning tokens (non-inpaint) get timestep=0
        # Inpaint tokens get the sampled sigma
        timesteps = self._create_per_token_timesteps(~inpaint_mask, sigmas.squeeze())

        # Compute targets (velocity prediction: noise - clean)
        targets = noise - video_latents

        # Get text embeddings
        conditions = batch["conditions"]
        video_prompt_embeds = conditions["video_prompt_embeds"]
        prompt_attention_mask = conditions["prompt_attention_mask"]

        # Generate position embeddings
        positions = self._get_video_positions(
            num_frames=num_frames,
            height=height,
            width=width,
            batch_size=batch_size,
            fps=24.0,  # Or get from latents_data
            device=device,
            dtype=dtype,
        )

        # Create video Modality
        video_modality = Modality(
            enabled=True,
            latent=noisy_latents,
            sigma=sigmas,
            timesteps=timesteps,
            positions=positions,
            context=video_prompt_embeds,
            context_mask=prompt_attention_mask,
        )

        # Loss mask: only compute loss on inpaint regions
        loss_mask = inpaint_mask

        return ModelInputs(
            video=video_modality,
            audio=None,
            video_targets=targets,
            audio_targets=None,
            video_loss_mask=loss_mask,
            audio_loss_mask=None,
        )

    def compute_loss(
        self,
        video_pred: Tensor,
        audio_pred: Tensor | None,
|
| 308 |
+
inputs: ModelInputs,
|
| 309 |
+
) -> Tensor:
|
| 310 |
+
"""Compute training loss on inpaint regions only."""
|
| 311 |
+
# MSE loss
|
| 312 |
+
loss = (video_pred - inputs.video_targets).pow(2)
|
| 313 |
+
|
| 314 |
+
# Apply loss mask
|
| 315 |
+
loss_mask = inputs.video_loss_mask.unsqueeze(-1).float()
|
| 316 |
+
loss = loss.mul(loss_mask).div(loss_mask.mean() + 1e-8)
|
| 317 |
+
|
| 318 |
+
return loss.mean()
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
### Step 5: Register the Strategy

You need to register your strategy in two places:

**1. Update [`src/ltx_trainer/training_strategies/__init__.py`](../src/ltx_trainer/training_strategies/__init__.py):**

```python
# Add import for your strategy
from ltx_trainer.training_strategies.inpainting import InpaintingConfig, InpaintingStrategy

# Add to the TrainingStrategyConfig type alias
TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig | InpaintingConfig

# Add to __all__
__all__ = [
    # ... existing exports ...
    "InpaintingConfig",
    "InpaintingStrategy",
]


# Add case in get_training_strategy()
def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy:
    match config:
        # ... existing cases ...
        case InpaintingConfig():
            strategy = InpaintingStrategy(config)
```

**2. Update [`src/ltx_trainer/config.py`](../src/ltx_trainer/config.py):**

```python
# Add import
from ltx_trainer.training_strategies.inpainting import InpaintingConfig

# Add to the TrainingStrategyConfig union with a Tag matching your strategy name
TrainingStrategyConfig = Annotated[
    Annotated[TextToVideoConfig, Tag("text_to_video")]
    | Annotated[VideoToVideoConfig, Tag("video_to_video")]
    | Annotated[InpaintingConfig, Tag("inpainting")],  # Add your config
    Discriminator(_get_strategy_discriminator),
]
```

### Step 6: Create a Configuration File

Create an example config in `configs/`:

```yaml
# configs/ltx2_inpainting_lora.yaml

model:
  model_path: "/path/to/ltx2.safetensors"
  text_encoder_path: "/path/to/gemma"
  training_mode: "lora"

training_strategy:
  name: "inpainting"  # Must match your Literal type
  mask_latents_dir: "mask_latents"
  mask_threshold: 0.5

lora:
  rank: 32
  alpha: 32
  target_modules:
    - "to_k"
    - "to_q"
    - "to_v"
    - "to_out.0"

data:
  preprocessed_data_root: "/path/to/preprocessed/dataset"

optimization:
  learning_rate: 1e-4
  steps: 2000
  batch_size: 1

# ... other config sections ...
```

## 🔧 Helper Methods Reference

The base `TrainingStrategy` class provides these helper methods:

| Method                                        | Purpose                                          |
|-----------------------------------------------|--------------------------------------------------|
| `_video_patchifier.patchify(latents)`         | Convert `[B, C, F, H, W]` → `[B, seq_len, C]`    |
| `_audio_patchifier.patchify(latents)`         | Convert `[B, C, T, F]` → `[B, T, C*F]`           |
| `_get_video_positions(...)`                   | Generate position embeddings for video           |
| `_get_audio_positions(...)`                   | Generate position embeddings for audio           |
| `_create_per_token_timesteps(mask, sigma)`    | Create timesteps with 0 for conditioning tokens  |
| `_create_first_frame_conditioning_mask(...)`  | Create mask for first-frame conditioning         |

## 📊 Understanding ModelInputs

The `ModelInputs` dataclass contains everything needed for the forward pass and loss computation:

```python
@dataclass
class ModelInputs:
    video: Modality          # Video modality data
    audio: Modality | None   # Audio modality (None if video-only)

    video_targets: Tensor    # Target values for loss (velocity)
    audio_targets: Tensor | None

    video_loss_mask: Tensor  # Boolean: True = compute loss for this token
    audio_loss_mask: Tensor | None

    ref_seq_len: int | None = None  # For IC-LoRA: reference sequence length
```

## 📊 Understanding Modality

The `Modality` dataclass (from ltx-core) represents a single modality's data:

```python
@dataclass(frozen=True)
class Modality:
    enabled: bool         # Whether this modality is active
    latent: Tensor        # [B, seq_len, C] - the latent tokens
    timesteps: Tensor     # [B, seq_len] - per-token timesteps (sigmas)
    positions: Tensor     # [B, dims, seq_len, 2] - position bounds
    context: Tensor       # [B, ctx_len, C] - text embeddings
    context_mask: Tensor  # [B, ctx_len] - attention mask for context
```

> [!NOTE]
> **Per-token timesteps:** Each token in the sequence has its own timestep. Conditioning tokens (those that should
> remain un-noised) must have `timestep=0`. This is how the model distinguishes clean reference tokens from tokens to
> denoise. Use `_create_per_token_timesteps(conditioning_mask, sigma)` to set this up correctly.

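For intuition, here is a minimal sketch of what per-token timestep assignment looks like. This is illustrative only; the actual `_create_per_token_timesteps` helper may differ in signature and details:

```python
import torch

# Hypothetical shapes for illustration: batch of 2, sequence of 6 tokens
conditioning_mask = torch.tensor([[True, True, False, False, False, False],
                                  [True, False, False, False, False, False]])
sigma = torch.tensor([0.7, 0.3])  # one sampled sigma per batch element

# Conditioning tokens get timestep 0; all other tokens get their sample's sigma
timesteps = torch.where(conditioning_mask, 0.0, sigma.unsqueeze(-1))
# timesteps[0] == [0.0, 0.0, 0.7, 0.7, 0.7, 0.7]
```
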
> [!NOTE]
> `Modality` is immutable (frozen dataclass). Use `dataclasses.replace()` to create modified copies.

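For example, to swap in a new latent tensor while keeping every other field (the generic frozen-dataclass pattern from the standard library; `video_modality` and `new_latents` are the names from the example above):

```python
import dataclasses

# Returns a new Modality with `latent` replaced; the original is untouched
updated_modality = dataclasses.replace(video_modality, latent=new_latents)
```
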
## ✅ Testing Your Strategy

1. **Verify your training configuration is valid:**

   ```bash
   uv run python -c "
   from ltx_trainer.config import LtxTrainerConfig
   import yaml

   with open('configs/ltx2_inpainting_lora.yaml') as f:
       config = LtxTrainerConfig(**yaml.safe_load(f))
   print(f'Strategy: {config.training_strategy.name}')
   "
   ```

2. **Test strategy instantiation:**

   ```bash
   uv run python -c "
   from ltx_trainer.training_strategies import get_training_strategy
   from ltx_trainer.training_strategies.inpainting import InpaintingConfig

   config = InpaintingConfig()
   strategy = get_training_strategy(config)
   print(f'Data sources: {strategy.get_data_sources()}')
   "
   ```

3. **Run a short training test:**

   ```bash
   uv run python scripts/train.py configs/ltx2_inpainting_lora.yaml
   ```

## 💡 Tips and Best Practices

### Debugging

- Set `data.num_dataloader_workers: 0` to get clearer error messages
- Use a small dataset and few steps for initial testing
- Check tensor shapes at each step with print statements, as in the sketch below

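A minimal shape-checking sketch for `prepare_training_inputs` (illustrative; place prints wherever shapes are in doubt):

```python
# Inside prepare_training_inputs(), after patchifying:
print(f"video_latents: {tuple(video_latents.shape)}")  # expect (B, seq_len, C)
print(f"inpaint_mask:  {tuple(inpaint_mask.shape)}")   # expect (B, seq_len)
print(f"timesteps:     {tuple(timesteps.shape)}")      # expect (B, seq_len)
```
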
## 🔗 Related Documentation

- [Training Modes](training-modes.md) - Overview of built-in training modes
- [Configuration Reference](configuration-reference.md) - All configuration options
- [Dataset Preparation](dataset-preparation.md) - Preprocessing workflow
- [ltx-core Documentation](../../ltx-core/README.md) - Core model components

## 📚 Reference: Existing Strategies

Study these implementations for guidance:

| Strategy                                                                            | Complexity | Key Features                                    |
|-------------------------------------------------------------------------------------|------------|-------------------------------------------------|
| [`TextToVideoStrategy`](../src/ltx_trainer/training_strategies/text_to_video.py)    | Simple     | First-frame conditioning, optional audio        |
| [`VideoToVideoStrategy`](../src/ltx_trainer/training_strategies/video_to_video.py)  | Medium     | Reference video concatenation, split loss mask  |

packages/ltx-trainer/docs/dataset-preparation.md
ADDED
@@ -0,0 +1,342 @@

# Dataset Preparation Guide

This guide covers the complete workflow for preparing and preprocessing your dataset for training.

## 📋 Overview

The general dataset preparation workflow is:

1. **(Optional)** Split long videos into scenes using `split_scenes.py`
2. **(Optional)** Generate captions for your videos using `caption_videos.py`
3. **Preprocess your dataset** using `process_dataset.py` to compute and cache video/audio latents and text embeddings
4. **Run the trainer** with your preprocessed dataset

## 🎬 Step 1: Split Scenes

If you're starting with raw, long-form videos (e.g., downloaded from YouTube), you should first split them into shorter, coherent scenes.

```bash
uv run python scripts/split_scenes.py input.mp4 scenes_output_dir/ \
  --filter-shorter-than 5s
```

This will create multiple video clips in `scenes_output_dir`.
These clips will be the input for the captioning step, if you choose to use it.

The script supports many configuration options for scene detection (detector algorithms, thresholds, minimum scene lengths, etc.):

```bash
uv run python scripts/split_scenes.py --help
```

## 📝 Step 2: Caption Videos

If your dataset doesn't include captions, you can automatically generate them using multimodal models that understand both video and audio.

```bash
uv run python scripts/caption_videos.py scenes_output_dir/ \
  --output scenes_output_dir/dataset.json
```

If you're running into VRAM issues, try enabling 8-bit quantization to reduce memory usage:

```bash
uv run python scripts/caption_videos.py scenes_output_dir/ \
  --output scenes_output_dir/dataset.json \
  --use-8bit
```

This will create a `dataset.json` file containing video paths and their captions.

**Captioning options:**

| Option | Description |
|--------|-------------|
| `--captioner-type` | `qwen_omni` (default, local) or `gemini_flash` (API) |
| `--use-8bit` | Enable 8-bit quantization for lower VRAM usage |
| `--no-audio` | Disable audio processing (video-only captions) |
| `--override` | Re-caption files that already have captions |
| `--api-key` | API key for Gemini Flash (or set `GOOGLE_API_KEY` env var) |

**Caption format:**

The captioner produces structured captions with sections for:

- **Visual content**: People, objects, actions, settings, colors, movements
- **Speech transcription**: Word-for-word transcription of spoken content
- **Sounds**: Music, ambient sounds, sound effects
- **On-screen text**: Any visible text overlays

> [!NOTE]
> The automatically generated captions may contain inaccuracies or hallucinated content.
> We recommend reviewing and correcting the generated captions in your `dataset.json` file before proceeding to preprocessing.

## ⚡ Step 3: Dataset Preprocessing

This step preprocesses your video dataset by:

1. Resizing and cropping videos to fit specified resolution buckets
2. Computing and caching video latent representations
3. Computing and caching text embeddings for captions
4. (Optional) Computing and caching audio latents

> [!WARNING]
> Very large videos (especially high spatial resolution and/or many frames) can cause GPU out-of-memory (OOM)
> during preprocessing/encoding.
> The simplest fix is to reduce the target resolution (spatially: width/height) and/or the number of frames
> (temporally) by using `--resolution-buckets` with smaller dimensions (lower width/height and/or fewer frames).

### Basic Usage

```bash
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model
```

### With Audio Processing

For audio-video training, add the `--with-audio` flag:

```bash
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model \
  --with-audio
```

### 📊 Dataset Format

The trainer supports either videos or single images.
Note that your dataset must be homogeneous - either all videos or all images; mixing is not supported.

> [!TIP]
> **Image Datasets:** When using images, follow the same preprocessing steps and format requirements as with videos,
> but use `1` for the frame count in the resolution bucket (e.g., `960x544x1`).

The dataset must be a CSV, JSON, or JSONL metadata file with columns for captions and video paths:

**JSON format example:**

```json
[
  {
    "caption": "A cat playing with a ball of yarn",
    "media_path": "videos/cat_playing.mp4"
  },
  {
    "caption": "A dog running in the park",
    "media_path": "videos/dog_running.mp4"
  }
]
```

**JSONL format example:**

```jsonl
{"caption": "A cat playing with a ball of yarn", "media_path": "videos/cat_playing.mp4"}
{"caption": "A dog running in the park", "media_path": "videos/dog_running.mp4"}
```

**CSV format example:**

```csv
caption,media_path
"A cat playing with a ball of yarn","videos/cat_playing.mp4"
"A dog running in the park","videos/dog_running.mp4"
```

### 📐 Resolution Buckets

Videos are organized into "buckets" of specific dimensions (width × height × frames).
Each video is assigned to the nearest matching bucket.
You can preprocess with one or multiple resolution buckets.
When training with multiple resolution buckets, you must use a batch size of 1.

The dimensions of each bucket must follow these constraints due to LTX-2's VAE architecture:

- **Spatial dimensions** (width and height) must be multiples of 32
- **Number of frames** must satisfy `frames % 8 == 1` (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121, etc.)

**Guidelines for choosing training resolution:**

- For high-quality, detailed videos: use larger spatial dimensions (e.g., 768×448) with fewer frames (e.g., 89)
- For longer, motion-focused videos: use smaller spatial dimensions (e.g., 512×512) with more frames (e.g., 121)
- Memory usage increases with both spatial and temporal dimensions

**Example usage:**

```bash
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model
```

Multiple buckets are supported by separating entries with `;`:

```bash
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49;512x512x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model
```

**Video processing workflow:**

1. Videos are **resized** maintaining aspect ratio until either width or height matches the target
2. The larger dimension is **center cropped** to match the bucket's dimensions
3. Only the **first X frames are taken** to match the bucket's frame count; remaining frames are ignored

> [!NOTE]
> The sequence length processed by the transformer model can be calculated as:
>
> ```
> sequence_length = (H/32) * (W/32) * ((F-1)/8 + 1)
> ```
>
> Where:
> - H = Height of video
> - W = Width of video
> - F = Number of frames
> - 32 = VAE's spatial downsampling factor
> - 8 = VAE's temporal downsampling factor
>
> For example, a 768×448×89 video would have sequence length:
>
> ```
> (768/32) * (448/32) * ((89-1)/8 + 1) = 24 * 14 * 12 = 4,032
> ```
>
> Keep this in mind when choosing video dimensions, as longer sequences require more GPU memory.

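To sanity-check a candidate bucket before preprocessing, a small helper like the following (illustrative, not part of the trainer's scripts) applies the constraints and formula above:

```python
def check_bucket(width: int, height: int, frames: int) -> int:
    """Validate an LTX-2 resolution bucket and return its transformer sequence length."""
    assert width % 32 == 0 and height % 32 == 0, "spatial dims must be multiples of 32"
    assert frames % 8 == 1, "frame count must satisfy frames % 8 == 1"
    # sequence_length = (H/32) * (W/32) * ((F-1)/8 + 1)
    return (height // 32) * (width // 32) * ((frames - 1) // 8 + 1)

print(check_bucket(768, 448, 89))  # 4032, matching the worked example above
```
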
> [!WARNING]
> When training with multiple resolution buckets, you must use a batch size of 1
> (i.e., set `optimization.batch_size: 1` in your training config).

### 📁 Output Structure

The preprocessed data is saved in a `.precomputed` directory:

```
dataset/
└── .precomputed/
    ├── latents/            # Cached video latents
    ├── conditions/         # Cached text embeddings
    ├── audio_latents/      # (only if --with-audio) Cached audio latents
    └── reference_latents/  # (only for IC-LoRA) Cached reference video latents
```

## 🪄 IC-LoRA Reference Video Preprocessing

For IC-LoRA training, you need to preprocess datasets that include reference videos.
Reference videos provide the conditioning input while target videos represent the desired transformed output.

### Dataset Format with Reference Videos

**JSON format:**

```json
[
  {
    "caption": "A cat playing with a ball of yarn",
    "media_path": "videos/cat_playing.mp4",
    "reference_path": "references/cat_playing_depth.mp4"
  }
]
```

**JSONL format:**

```jsonl
{"caption": "A cat playing with a ball of yarn", "media_path": "videos/cat_playing.mp4", "reference_path": "references/cat_playing_depth.mp4"}
{"caption": "A dog running in the park", "media_path": "videos/dog_running.mp4", "reference_path": "references/dog_running_depth.mp4"}
```

### Preprocessing with Reference Videos

To preprocess a dataset with reference videos, add the `--reference-column` argument specifying the name of the field
in your dataset JSON/JSONL/CSV that contains the reference video paths:

```bash
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model \
  --reference-column "reference_path"
```

This will create an additional `reference_latents/` directory containing the preprocessed reference video latents.

### Generating Reference Videos

**Dataset Requirements for IC-LoRA:**

- Your dataset must contain paired videos where each target video has a corresponding reference video
- Reference and target videos must have *identical* resolution and length
- Both reference and target videos should be preprocessed together using the same resolution buckets

We provide an example script, [`scripts/compute_reference.py`](../scripts/compute_reference.py), to generate reference
videos for a given dataset. The default implementation generates Canny edge reference videos.

```bash
uv run python scripts/compute_reference.py scenes_output_dir/ \
  --output scenes_output_dir/dataset.json
```

The script accepts a JSON file as the dataset configuration and updates it in-place by adding the filenames of the generated reference videos.

If you want to generate a different type of condition (depth maps, pose skeletons, etc.), modify or replace the `compute_reference()` function within this script.

### Example Dataset

For reference, see our **[Canny Control Dataset](https://huggingface.co/datasets/Lightricks/Canny-Control-Dataset)**, which demonstrates proper IC-LoRA dataset structure with paired videos and Canny edge maps.

## 🎯 LoRA Trigger Words

When training a LoRA, you can specify a trigger token that will be prepended to all captions:

```bash
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model \
  --lora-trigger "MYTRIGGER"
```

This acts as a trigger word that activates the LoRA during inference when you include the same token in your prompts.

> [!NOTE]
> There is no need to manually insert the trigger word into your dataset JSON/JSONL/CSV file.
> The trigger word specified with `--lora-trigger` is automatically prepended to each caption during preprocessing.

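For illustration, assuming the trigger is prepended with a simple space separator (the exact formatting is handled by the preprocessing script), the effective caption changes roughly like this:

```python
caption = "A cat playing with a ball of yarn"
trigger = "MYTRIGGER"  # value passed to --lora-trigger
processed = f"{trigger} {caption}"  # "MYTRIGGER A cat playing with a ball of yarn"
```
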
## 🔍 Decoding Videos for Verification

If you add the `--decode` flag, the script will VAE-decode the precomputed latents and save the resulting videos
in `.precomputed/decoded_videos`. When audio preprocessing is enabled (`--with-audio`), audio latents will also be
decoded and saved to `.precomputed/decoded_audio`. This allows you to visually and audibly inspect the processed data.

```bash
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model \
  --decode
```

For single-frame images, the decoded latents will be saved as PNG files rather than MP4 videos.

## 🚀 Next Steps

Once your dataset is preprocessed, you can proceed to:

- Configure your training parameters in [Configuration Reference](configuration-reference.md)
- Choose your training approach in [Training Modes](training-modes.md)
- Start training with the [Training Guide](training-guide.md)

> [!TIP]
> If your training recipe requires additional preprocessed data (e.g., masks, conditioning signals), see
> [Implementing Custom Training Strategies](custom-training-strategies.md) for guidance on extending the
> preprocessing pipeline.

packages/ltx-trainer/docs/quick-start.md
ADDED
@@ -0,0 +1,130 @@

# Quick Start Guide

Get up and running with LTX-2 training in just a few steps!

## 📋 Prerequisites

Before you begin, ensure you have:

1. **LTX-2 Model Checkpoint** - A local `.safetensors` file containing the LTX-2 model weights.
   Download `ltx-2-19b-dev.safetensors` from: [HuggingFace Hub](https://huggingface.co/Lightricks/LTX-2)
2. **Gemma Text Encoder** - A local directory containing the Gemma model (required for LTX-2).
   Download from: [HuggingFace Hub](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/)
3. **Linux with CUDA** - The trainer requires `triton`, which is Linux-only
4. **GPU with sufficient VRAM** - 80GB recommended for the standard config. For GPUs with 32GB VRAM (e.g., RTX 5090),
   use the [low VRAM config](../configs/ltx2_av_lora_low_vram.yaml), which enables INT8 quantization and other
   memory optimizations

## ⚡ Installation

First, install [uv](https://docs.astral.sh/uv/getting-started/installation/) if you haven't already.
Then clone the repository and install the dependencies:

```bash
git clone https://github.com/Lightricks/LTX-2
```

The `ltx-trainer` package is part of the `LTX-2` monorepo. Install the dependencies from the repository root,
then navigate to the trainer package:

```bash
# From the repository root
uv sync
cd packages/ltx-trainer
```

> [!NOTE]
> The trainer depends on the [`ltx-core`](../../ltx-core/) and [`ltx-pipelines`](../../ltx-pipelines/)
> packages, which are automatically installed from the monorepo.

## 🏋 Training Workflow

### 1. Prepare Your Dataset

Organize your videos and captions, then preprocess them:

```bash
# Split long videos into scenes (optional)
uv run python scripts/split_scenes.py input.mp4 scenes_output_dir/ --filter-shorter-than 5s

# Generate captions for videos (optional)
uv run python scripts/caption_videos.py scenes_output_dir/ --output dataset.json

# Preprocess the dataset (compute latents and embeddings)
uv run python scripts/process_dataset.py dataset.json \
  --resolution-buckets "960x544x49" \
  --model-path /path/to/ltx-2-model.safetensors \
  --text-encoder-path /path/to/gemma-model
```

See [Dataset Preparation](dataset-preparation.md) for detailed instructions.

### 2. Configure Training

Create or modify a configuration YAML file. Start with one of the example configs:

- [`configs/ltx2_av_lora.yaml`](../configs/ltx2_av_lora.yaml) - Audio-video LoRA training
- [`configs/ltx2_av_lora_low_vram.yaml`](../configs/ltx2_av_lora_low_vram.yaml) - Audio-video LoRA training (optimized for 32GB VRAM)
- [`configs/ltx2_v2v_ic_lora.yaml`](../configs/ltx2_v2v_ic_lora.yaml) - IC-LoRA video-to-video

Key settings to update:

```yaml
model:
  model_path: "/path/to/ltx-2-model.safetensors"
  text_encoder_path: "/path/to/gemma-model"

data:
  preprocessed_data_root: "/path/to/preprocessed/data"

output_dir: "outputs/my_training_run"
```

See [Configuration Reference](configuration-reference.md) for all available options.

### 3. Start Training

```bash
uv run python scripts/train.py configs/ltx2_av_lora.yaml
```

For multi-GPU training:

```bash
uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml
```

See [Training Guide](training-guide.md) for distributed training and advanced options.

## 🎯 Training Modes

The trainer supports several training modes:

| Mode                 | Description                    | Config Example                             |
|----------------------|--------------------------------|--------------------------------------------|
| **LoRA**             | Efficient adapter training     | `training_strategy.name: "text_to_video"`  |
| **Audio-Video LoRA** | Joint audio-video training     | `training_strategy.with_audio: true`       |
| **IC-LoRA**          | Video-to-video transformations | `training_strategy.name: "video_to_video"` |
| **Full Fine-tuning** | Full model training            | `model.training_mode: "full"`              |

See [Training Modes](training-modes.md) for detailed explanations,
or [Custom Training Strategies](custom-training-strategies.md) if you need to implement your own training recipe.

## Next Steps

Once you've completed your first training run, you can:

- **Use your trained LoRA for inference** - The [`ltx-pipelines`](../../ltx-pipelines/) package provides
  production-ready inference pipelines for various use cases (T2V, I2V, IC-LoRA, etc.). See the package
  documentation for details.
- Learn more about [Dataset Preparation](dataset-preparation.md) for advanced preprocessing
- Explore different [Training Modes](training-modes.md) (LoRA, Audio-Video, IC-LoRA)
- Dive deeper into [Training Configuration](configuration-reference.md)
- Understand the model architecture in [LTX-Core Documentation](../../ltx-core/README.md)

## Need Help?

If you run into issues at any step, see the [Troubleshooting Guide](troubleshooting.md) for solutions to common
problems.

Join our [Discord community](https://discord.gg/ltxplatform) for real-time help and discussion!

packages/ltx-trainer/docs/training-guide.md
ADDED
@@ -0,0 +1,203 @@

# Training Guide

This guide covers how to run training jobs, from basic single-GPU training to advanced distributed setups and automatic
model uploads.

## ⚡ Basic Training (Single GPU)

After preprocessing your dataset and preparing a configuration file, you can start training using the trainer script:

```bash
uv run python scripts/train.py configs/ltx2_av_lora.yaml
```

The trainer will:

1. **Load your configuration** and validate all parameters
2. **Initialize models** and apply optimizations
3. **Run the training loop** with progress tracking
4. **Generate validation videos** (if configured)
5. **Save the trained weights** in your output directory

### Output Files

**For LoRA training:**

- `lora_weights.safetensors` - Main LoRA weights file
- `training_config.yaml` - Copy of training configuration
- `validation_samples/` - Generated validation videos (if enabled)

**For full model fine-tuning:**

- `model_weights.safetensors` - Full model weights
- `training_config.yaml` - Copy of training configuration
- `validation_samples/` - Generated validation videos (if enabled)

## 🖥️ Distributed / Multi-GPU Training

We use Hugging Face 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) for multi-GPU DDP and FSDP.

### Configure Accelerate

Run the interactive wizard once to set up your environment (DDP / FSDP, GPU count, etc.):

```bash
uv run accelerate config
```

This stores your preferences in `~/.cache/huggingface/accelerate/default_config.yaml`.

### Use the Provided Accelerate Configs (Recommended)

We include ready-to-use Accelerate config files in `configs/accelerate/`:

- [ddp.yaml](../configs/accelerate/ddp.yaml) — Standard DDP
- [ddp_compile.yaml](../configs/accelerate/ddp_compile.yaml) — DDP with `torch.compile` (Inductor)
- [fsdp.yaml](../configs/accelerate/fsdp.yaml) — Standard FSDP (auto-wraps `BasicAVTransformerBlock`)
- [fsdp_compile.yaml](../configs/accelerate/fsdp_compile.yaml) — FSDP with `torch.compile` (Inductor)

Launch with a specific config using `--config_file`:

```bash
# DDP (2 GPUs shown as example)
CUDA_VISIBLE_DEVICES=0,1 \
uv run accelerate launch --config_file configs/accelerate/ddp.yaml \
  scripts/train.py configs/ltx2_av_lora.yaml

# DDP + torch.compile
CUDA_VISIBLE_DEVICES=0,1 \
uv run accelerate launch --config_file configs/accelerate/ddp_compile.yaml \
  scripts/train.py configs/ltx2_av_lora.yaml

# FSDP (4 GPUs shown as example)
CUDA_VISIBLE_DEVICES=0,1,2,3 \
uv run accelerate launch --config_file configs/accelerate/fsdp.yaml \
  scripts/train.py configs/ltx2_av_lora.yaml

# FSDP + torch.compile
CUDA_VISIBLE_DEVICES=0,1,2,3 \
uv run accelerate launch --config_file configs/accelerate/fsdp_compile.yaml \
  scripts/train.py configs/ltx2_av_lora.yaml
```

**Notes:**

- The number of processes is taken from the Accelerate config (`num_processes`). Override with `--num_processes X` or
  restrict GPUs with `CUDA_VISIBLE_DEVICES`.
- The compile variants enable `torch.compile` with the Inductor backend via Accelerate's `dynamo_config`.
- FSDP configs auto-wrap the transformer blocks (`fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock`).

### Launch with Your Default Accelerate Config

If you prefer to use your default Accelerate profile:

```bash
# Use settings from your default accelerate config
uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml

# Override number of processes on the fly (e.g., 2 GPUs)
uv run accelerate launch --num_processes 2 scripts/train.py configs/ltx2_av_lora.yaml

# Select specific GPUs
CUDA_VISIBLE_DEVICES=0,1 uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml
```

> [!TIP]
> You can disable the in-terminal progress bars with the `--disable-progress-bars` flag in the trainer CLI if desired.

### Benefits of Distributed Training

- **Faster training**: Distribute workload across multiple GPUs
- **Larger effective batch sizes**: Combine gradients from multiple GPUs
- **Memory efficiency**: Each GPU handles a portion of the batch

> [!NOTE]
> Distributed training requires that all GPUs have sufficient memory for the model and batch size. The effective batch
> size becomes `batch_size × num_processes`.

## 🤗 Pushing Models to Hugging Face Hub

You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration:

```yaml
hub:
  push_to_hub: true
  hub_model_id: "your-username/your-model-name"
```

### Prerequisites

Before pushing, make sure you:

1. **Have a Hugging Face account** - Sign up at [huggingface.co](https://huggingface.co)
2. **Are logged in** via `huggingface-cli login` or have set the `HUGGING_FACE_HUB_TOKEN` environment variable
3. **Have write access** to the specified repository (it will be created if it doesn't exist)

### Login Options

**Option 1: Interactive login**

```bash
uv run huggingface-cli login
```

**Option 2: Environment variable**

```bash
export HUGGING_FACE_HUB_TOKEN="your_token_here"
```

### What Gets Uploaded

The trainer will automatically:

- **Create a model card** with training details and sample outputs
- **Upload model weights**
- **Push sample videos as GIFs** in the model card
- **Include training configuration and prompts**

## 📊 Weights & Biases Logging

Enable experiment tracking with W&B by adding to your configuration:

```yaml
wandb:
  enabled: true
  project: "ltx-2-trainer"
  entity: null  # Your W&B username or team
  tags: [ "ltx2", "lora" ]
  log_validation_videos: true
```

This will log:

- Training loss and learning rate
- Validation videos
- Model configuration
- Training progress

## 🚀 Next Steps

After training completes:

- **Run inference with your trained LoRA** - The [`ltx-pipelines`](../../ltx-pipelines/) package provides
  production-ready inference pipelines that support loading custom LoRAs. Available pipelines include text-to-video,
  image-to-video, IC-LoRA video-to-video, and more. See the [`ltx-pipelines`](../../ltx-pipelines/) package for usage details.
- **Test your model** with validation prompts
- **Iterate and improve** based on validation results
- **Share your results** by pushing to Hugging Face Hub

## 💡 Tips for Successful Training

- **Start small**: Begin with a small dataset and a few hundred steps to verify everything works
- **Monitor validation**: Keep an eye on validation samples to catch overfitting
- **Adjust learning rate**: Lower learning rates often produce better results
- **Use gradient checkpointing**: Essential for training with limited GPU memory
- **Save checkpoints**: Regular checkpoints help recover from interruptions

## Need Help?

If you encounter issues during training, see the [Troubleshooting Guide](troubleshooting.md).

Join our [Discord community](https://discord.gg/ltxplatform) for real-time help!

packages/ltx-trainer/docs/training-modes.md
ADDED
@@ -0,0 +1,277 @@

# Training Modes Guide

The trainer supports several training modes, each suited for different use cases and requirements.

## 🎯 Standard LoRA Training (Video-Only)

Standard LoRA (Low-Rank Adaptation) training fine-tunes the model by adding small, trainable adapter layers while
keeping the base model frozen. This approach:

- **Requires significantly less memory and compute** than full fine-tuning
- **Produces small, portable weight files** (typically a few hundred MB)
- **Is ideal for learning specific styles, effects, or concepts**
- **Can be easily combined with other LoRAs** during inference

Configure standard LoRA training with:

```yaml
model:
  training_mode: "lora"

training_strategy:
  name: "text_to_video"
  first_frame_conditioning_p: 0.1
  with_audio: false  # Video-only training
```

## 🔊 Audio-Video LoRA Training

LTX-2 supports joint audio-video generation. You can train LoRA adapters that affect both video and audio output:

- **Synchronized audio-video generation** - Audio matches the visual content
- **Same efficient LoRA approach** - Just enable audio training
- **Requires audio latents** - Dataset must include preprocessed audio

Configure audio-video training with:

```yaml
model:
  training_mode: "lora"

training_strategy:
  name: "text_to_video"
  first_frame_conditioning_p: 0.1
  with_audio: true  # Enable audio training
  audio_latents_dir: "audio_latents"  # Directory containing audio latents
```

**Example configuration file:**

- 📄 [Audio-Video LoRA Training](../configs/ltx2_av_lora.yaml)

**Dataset structure for audio-video training:**

```
preprocessed_data_root/
├── latents/        # Video latents
├── conditions/     # Text embeddings
└── audio_latents/  # Audio latents (required when with_audio: true)
```

> [!IMPORTANT]
> When training audio-video LoRAs, ensure your `target_modules` configuration captures video, audio, and
> cross-modal attention branches. Use patterns like `"to_k"` instead of `"attn1.to_k"` to match:
> - Video modules: `attn1.to_k`, `attn2.to_k`
> - Audio modules: `audio_attn1.to_k`, `audio_attn2.to_k`
> - Cross-modal modules: `audio_to_video_attn.to_k`, `video_to_audio_attn.to_k`
>
> The cross-modal attention modules (`audio_to_video_attn` and `video_to_audio_attn`) enable bidirectional
> information flow between audio and video, which is critical for synchronized audiovisual generation.
> See [Understanding Target Modules](configuration-reference.md#understanding-target-modules) for detailed guidance.

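To see why the broader pattern matters, here is a quick sketch of suffix-style target-module matching (assuming PEFT-like behavior, where a pattern matches a module whose qualified name ends with it at a `.` boundary; not necessarily the trainer's exact implementation):

```python
# Hypothetical module names, following the branch naming described above
module_names = [
    "blocks.0.attn1.to_k", "blocks.0.attn2.to_k",              # video attention
    "blocks.0.audio_attn1.to_k", "blocks.0.audio_attn2.to_k",  # audio attention
    "blocks.0.audio_to_video_attn.to_k",                       # cross-modal
    "blocks.0.video_to_audio_attn.to_k",                       # cross-modal
]

def matches(name: str, target: str) -> bool:
    # Suffix match at a "." boundary, as in PEFT-style target_modules handling
    return name == target or name.endswith("." + target)

print([n for n in module_names if matches(n, "to_k")])        # all six branches
print([n for n in module_names if matches(n, "attn1.to_k")])  # only blocks.0.attn1.to_k
```
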
> [!NOTE]
|
| 73 |
+
> You can generate audio during validation even if you're not training the audio branch.
|
| 74 |
+
> Set `validation.generate_audio: true` independently of `training_strategy.with_audio`.
|
| 75 |
+
|
| 76 |
+
## 🔥 Full Model Fine-tuning
|
| 77 |
+
|
| 78 |
+
Full model fine-tuning updates all parameters of the base model, providing maximum flexibility but
|
| 79 |
+
requiring substantial computational resources and larger training datasets:
|
| 80 |
+
|
| 81 |
+
- **Offers the highest potential quality and capability improvements**
|
| 82 |
+
- **Requires multiple GPUs** and distributed training techniques (e.g., FSDP)
|
| 83 |
+
- **Produces large checkpoint files** (several GB)
|
| 84 |
+
- **Best for major model adaptations** or when LoRA limitations are reached
|
| 85 |
+
|
| 86 |
+
Configure full fine-tuning with:
|
| 87 |
+
|
| 88 |
+
```yaml
|
| 89 |
+
model:
|
| 90 |
+
training_mode: "full"
|
| 91 |
+
|
| 92 |
+
training_strategy:
|
| 93 |
+
name: "text_to_video"
|
| 94 |
+
first_frame_conditioning_p: 0.1
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
> [!IMPORTANT]
|
| 98 |
+
> Full fine-tuning of LTX-2 requires multiple high-end GPUs (e.g., 4-8× H100 80GB) and distributed
|
| 99 |
+
> training with FSDP. See [Training Guide](training-guide.md) for multi-GPU setup instructions.
|
| 100 |
+
|
| 101 |
+
## 🔄 In-Context LoRA (IC-LoRA) Training
|
| 102 |
+
|
| 103 |
+
IC-LoRA is a specialized training mode for video-to-video transformations.
|
| 104 |
+
Unlike standard training modes that learn from individual videos, IC-LoRA learns transformations from pairs of videos.
|
| 105 |
+
IC-LoRA enables a wide range of advanced video-to-video applications, such as:
|
| 106 |
+
|
| 107 |
+
- **Control adapters** (e.g., Depth, Pose): Learn to map from a control signal (like a depth map or pose skeleton) to a
|
| 108 |
+
target video
|
| 109 |
+
- **Video deblurring**: Transform blurry input videos into sharp, high-quality outputs
|
| 110 |
+
- **Style transfer**: Apply the style of a reference video to a target video sequence
|
| 111 |
+
- **Colorization**: Convert grayscale reference videos into colorized outputs
|
| 112 |
+
- **Restoration and enhancement**: Denoise, upscale, or restore old or degraded videos
|
| 113 |
+
|
| 114 |
+
By providing paired reference and target videos, IC-LoRA can learn complex transformations that go beyond caption-based
|
| 115 |
+
conditioning.
|
| 116 |
+
|
| 117 |
+
IC-LoRA training fundamentally differs from standard LoRA and full fine-tuning:
|
| 118 |
+
|
| 119 |
+
- **Reference videos** provide clean, unnoised conditioning input showing the "before" state
|
| 120 |
+
- **Target videos** are noised during training and represent the desired "after" state
|
| 121 |
+
- **The model learns transformations** from reference videos to target videos
|
| 122 |
+
- **Loss is applied only to the target portion**, not the reference
|
| 123 |
+
- **Training and inference time increase significantly** due to the doubled sequence length

To enable IC-LoRA training, configure your YAML file with:

```yaml
model:
  training_mode: "lora"  # Required: IC-LoRA uses LoRA mode

training_strategy:
  name: "video_to_video"
  first_frame_conditioning_p: 0.1
  reference_latents_dir: "reference_latents"  # Directory for reference video latents
```

**Example configuration file:**

- 📄 [IC-LoRA Training](../configs/ltx2_v2v_ic_lora.yaml) - Video-to-video transformation training

### Dataset Requirements for IC-LoRA

- Your dataset must contain **paired videos** where each target video has a corresponding reference video
- Reference and target videos must have the **same frame count** (length)
- Reference videos can optionally be at **lower spatial resolution** than target videos
  (see [Scaled Reference Conditioning](#scaled-reference-conditioning) below)
- Both reference and target videos should be **preprocessed** before training

**Dataset structure for IC-LoRA training:**

```
preprocessed_data_root/
├── latents/            # Target video latents (what the model learns to generate)
├── conditions/         # Text embeddings for each video
└── reference_latents/  # Reference video latents (conditioning input)
```

### Generating Reference Videos

We provide an example script to generate reference videos (e.g., Canny edge maps) for a given dataset.
The script takes a JSON file as input (e.g., the output of `caption_videos.py`) and updates it with the generated
reference video paths.

```bash
uv run python scripts/compute_reference.py scenes_output_dir/ \
    --output scenes_output_dir/dataset.json
```

To compute a different condition (depth maps, pose skeletons, etc.), modify the `compute_reference()` function in the
script.
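For example, the default Canny-edge condition boils down to a per-frame transform like the sketch below; swapping the
body for a depth or pose estimator yields a different control adapter. The per-frame signature here is an illustrative
assumption (the actual script operates on whole videos):

```python
import cv2
import numpy as np

def compute_reference(frame: np.ndarray) -> np.ndarray:
    """Turn one RGB video frame into a conditioning frame.

    This sketch mirrors the default Canny-edge behavior; replace the body
    with a depth estimator, pose detector, etc. for other conditions.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, threshold1=100, threshold2=200)
    # Replicate the single edge channel to RGB so the reference video keeps
    # the same channel layout as the target video.
    return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
```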

### Configuration Requirements for IC-LoRA

- You **must** provide `reference_videos` in your validation configuration when using IC-LoRA training
- The number of reference videos must match the number of validation prompts

Example validation configuration for IC-LoRA:

```yaml
validation:
  prompts:
    - "First prompt describing the desired output"
    - "Second prompt describing the desired output"
  reference_videos:
    - "/path/to/reference1.mp4"
    - "/path/to/reference2.mp4"
  reference_downscale_factor: 1  # Set to match preprocessing (e.g., 2 for half resolution)
  include_reference_in_output: true  # Show reference side-by-side with output
```

### Scaled Reference Conditioning

For more efficient training and inference, you can use **downscaled reference videos** while keeping target videos at
full resolution. This reduces the number of conditioning tokens, leading to:

- **Faster training** due to shorter sequence lengths
- **Faster inference** with reduced memory usage
- **Same aspect ratio** maintained between reference and target

#### How It Works

When the reference video has resolution `H/n × W/n` and the target video has resolution `H × W`, the trainer
automatically detects this scale factor `n` and adjusts the positional encodings so that the reference positions
map to the correct locations in the target coordinate space.
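A minimal sketch of the idea (the names are illustrative, not the trainer's code): infer the integer scale factor from
the spatial dimensions, then stretch the reference token coordinates back into the target coordinate space so the
positional encodings of the two videos line up.

```python
def infer_scale_factor(target_hw: tuple[int, int], reference_hw: tuple[int, int]) -> int:
    """Infer n such that the reference is an exact H/n x W/n downscale of the target."""
    (th, tw), (rh, rw) = target_hw, reference_hw
    if th % rh or tw % rw or th // rh != tw // rw:
        raise ValueError("Reference must be an integer downscale of the target")
    return th // rh

n = infer_scale_factor((768, 768), (384, 384))  # -> 2
# A reference grid position (y, x) then maps to (y * n, x * n) in target coordinates.
```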

#### Preprocessing Datasets with Scaled References

Use the `--reference-downscale-factor` option when running `process_dataset.py`:

```bash
# Process dataset with scaled reference videos (half resolution)
uv run python scripts/process_dataset.py dataset.json \
    --resolution-buckets 768x768x25 \
    --model-path /path/to/ltx2.safetensors \
    --text-encoder-path /path/to/gemma \
    --reference-column "reference_path" \
    --reference-downscale-factor 2
```

This will:

- Process target videos at 768×768 resolution
- Process reference videos at 384×384 resolution (768 / 2)
- Let the trainer automatically infer the scale factor from the dimension ratio

**Important**: Set `reference_downscale_factor: 2` in your validation configuration to match the preprocessing:

```yaml
validation:
  reference_downscale_factor: 2  # Must match the preprocessing factor
  reference_videos:
    - "/path/to/reference1.mp4"
    - "/path/to/reference2.mp4"
```

> [!NOTE]
> The scale factor must be a positive integer, and all dimensions must be divisible by 32.
> Common scale factors are 1 (no scaling), 2 (half resolution), or 4 (quarter resolution).

## 📊 Training Mode Comparison

| Aspect                | LoRA                           | Audio-Video LoRA               | Full Fine-tuning | IC-LoRA                        |
|-----------------------|--------------------------------|--------------------------------|------------------|--------------------------------|
| **Memory Usage**      | Low                            | Low-Medium                     | High             | Medium                         |
| **Training Speed**    | Fast                           | Fast                           | Slow             | Medium                         |
| **Output Size**       | 100MB-few GB (depends on rank) | 100MB-few GB (depends on rank) | Tens of GB       | 100MB-few GB (depends on rank) |
| **Flexibility**       | Medium                         | Medium                         | High             | Specialized                    |
| **Audio Support**     | Optional                       | Yes                            | Optional         | No                             |
| **Reference Videos**  | No                             | No                             | No               | Yes (required)                 |

## 🎬 Using Trained Models for Inference

After training, use the [`ltx-pipelines`](../../ltx-pipelines/) package for production inference with your trained
LoRAs:

| Training Mode           | Recommended Pipeline                                  |
|-------------------------|-------------------------------------------------------|
| LoRA / Audio-Video LoRA | `TI2VidOneStagePipeline` or `TI2VidTwoStagesPipeline` |
| IC-LoRA                 | `ICLoraPipeline`                                      |

All pipelines support loading custom LoRAs via the `loras` parameter. See the [`ltx-pipelines`](../../ltx-pipelines/)
package documentation for detailed usage instructions.

## 🚀 Next Steps

Once you've chosen your training mode:

- Set up your dataset using [Dataset Preparation](dataset-preparation.md)
- Configure your training parameters in [Configuration Reference](configuration-reference.md)
- Start training with the [Training Guide](training-guide.md)

> [!TIP]
> Need a training mode that's not covered here?
> See [Implementing Custom Training Strategies](custom-training-strategies.md)
> to learn how to create your own strategy for specialized use cases like video inpainting, audio-only training, or
> custom conditioning.
packages/ltx-trainer/docs/troubleshooting.md
ADDED
@@ -0,0 +1,300 @@
# Troubleshooting Guide

This guide covers common issues and solutions when training with the LTX-2 trainer.

## 🔧 VRAM and Memory Issues

Memory management is crucial for successful training with LTX-2.

> [!TIP]
> For GPUs with 32GB VRAM, use the pre-configured low VRAM config:
> [`configs/ltx2_av_lora_low_vram.yaml`](../configs/ltx2_av_lora_low_vram.yaml),
> which combines an 8-bit optimizer, INT8 quantization, and a reduced LoRA rank.

### Memory Optimization Techniques

#### 1. Enable Gradient Checkpointing

Gradient checkpointing trades training speed for memory savings. **Highly recommended** for most training runs:

```yaml
optimization:
  enable_gradient_checkpointing: true
```

#### 2. Enable 8-bit Text Encoder

Load the Gemma text encoder in 8-bit precision to save GPU memory:

```yaml
acceleration:
  load_text_encoder_in_8bit: true
```

#### 3. Reduce Batch Size

Lower the batch size if you encounter out-of-memory errors:

```yaml
optimization:
  batch_size: 1  # Start with 1 and increase gradually
```

Use gradient accumulation to maintain a larger effective batch size:

```yaml
optimization:
  batch_size: 1
  gradient_accumulation_steps: 4  # Effective batch size = 4
```

#### 4. Use Lower Resolution

Reduce spatial or temporal dimensions to save memory:

```bash
# Smaller spatial resolution
uv run python scripts/process_dataset.py dataset.json \
    --resolution-buckets "512x512x49" \
    --model-path /path/to/model.safetensors \
    --text-encoder-path /path/to/gemma

# Fewer frames
uv run python scripts/process_dataset.py dataset.json \
    --resolution-buckets "960x544x25" \
    --model-path /path/to/model.safetensors \
    --text-encoder-path /path/to/gemma
```

#### 5. Enable Model Quantization

Use quantization to reduce memory usage:

```yaml
acceleration:
  quantization: "int8-quanto"  # Options: int8-quanto, int4-quanto, fp8-quanto
```

#### 6. Use 8-bit Optimizer

The 8-bit AdamW optimizer uses less memory:

```yaml
optimization:
  optimizer_type: "adamw8bit"
```

---

## ⚠️ Common Usage Issues

### Issue: "No module named 'ltx_trainer'" Error

**Solution:**
Ensure you've installed the dependencies and are using `uv run` to execute scripts:

```bash
# From the repository root
uv sync
cd packages/ltx-trainer
uv run python scripts/train.py configs/ltx2_av_lora.yaml
```

> [!TIP]
> Always use `uv run` to execute Python scripts. This automatically uses the correct virtual environment
> without requiring manual activation.

### Issue: "Gemma model path is not a directory" Error

**Solution:**
The `text_encoder_path` must point to a directory containing the Gemma model, not a file:

```yaml
model:
  model_path: "/path/to/ltx-2-model.safetensors"  # File path
  text_encoder_path: "/path/to/gemma-model/"      # Directory path
```

### Issue: "Model path does not exist" Error

**Solution:**
LTX-2 requires local model paths. URLs are not supported:

```yaml
# ✅ Correct - local path
model:
  model_path: "/path/to/ltx-2-model.safetensors"

# ❌ Wrong - URL not supported
model:
  model_path: "https://huggingface.co/..."
```

### Issue: "Frames must satisfy frames % 8 == 1" Error

**Solution:**
LTX-2 requires the number of frames to satisfy `frames % 8 == 1`:

- ✅ Valid: 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121
- ❌ Invalid: 24, 32, 48, 64, 100
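
If your source clips have arbitrary lengths, you can snap a frame count down to the nearest valid value before
preprocessing. A small helper you can adapt (not part of the trainer):

```python
def nearest_valid_frame_count(frames: int) -> int:
    """Round a frame count down to the nearest value with frames % 8 == 1."""
    return max(1, (frames - 1) // 8 * 8 + 1)

print(nearest_valid_frame_count(100))  # -> 97
print(nearest_valid_frame_count(24))   # -> 17
```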

### Issue: Slow Training Speed

**Optimizations:**

1. **Disable gradient checkpointing** (if you have enough VRAM):

   ```yaml
   optimization:
     enable_gradient_checkpointing: false
   ```

2. **Use torch.compile** via Accelerate:

   ```bash
   uv run accelerate launch --config_file configs/accelerate/ddp_compile.yaml \
       scripts/train.py configs/ltx2_av_lora.yaml
   ```

### Issue: Poor Quality Validation Outputs

**Solutions:**

1. **Use Image-to-Video Validation:**
   For more reliable validation, use image-to-video (first-frame conditioning) rather than pure text-to-video:

   ```yaml
   validation:
     prompts:
       - "a professional portrait video of a person"
     images:
       - "/path/to/first_frame.png"  # One image per prompt
   ```

2. **Increase inference steps:**

   ```yaml
   validation:
     inference_steps: 50  # Default is 30
   ```

3. **Adjust guidance settings:**

   ```yaml
   validation:
     guidance_scale: 4.0  # CFG scale (recommended: 4.0)
     stg_scale: 1.0       # STG scale for temporal coherence (recommended: 1.0)
     stg_blocks: [29]     # Transformer block to perturb
   ```

4. **Check caption quality:**
   Review and manually edit captions for accuracy if using auto-generated captions.
   LTX-2 prefers long, detailed captions that describe both visual content and audio (e.g., ambient sounds, speech,
   music).

5. **Check target modules:**
   Ensure your `target_modules` configuration matches your training goals. For audio-video training,
   use patterns that match both branches (e.g., `"to_k"` instead of `"attn1.to_k"`).
   See [Understanding Target Modules](configuration-reference.md#understanding-target-modules) for details.

6. **Adjust LoRA rank:**
   Try higher values for more capacity:

   ```yaml
   lora:
     rank: 64  # Or 128 for more capacity
   ```

7. **Increase training steps:**

   ```yaml
   optimization:
     steps: 3000
   ```

---

## 🔍 Debugging Tools

### Monitor GPU Memory Usage

Track memory usage during training:

```bash
# Watch GPU memory in real-time
watch -n 1 nvidia-smi

# Log memory usage to file
nvidia-smi --query-gpu=memory.used,memory.total --format=csv --loop=5 > memory_log.csv
```
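
From inside a training script, PyTorch's own counters give a tensor-level view that complements `nvidia-smi`. A
minimal sketch:

```python
import torch

# Current and peak memory actually allocated by tensors on GPU 0.
# nvidia-smi also counts the CUDA context and the caching allocator's reserve,
# so these numbers are typically lower than what it reports.
allocated_gb = torch.cuda.memory_allocated(0) / 1024**3
peak_gb = torch.cuda.max_memory_allocated(0) / 1024**3
print(f"allocated: {allocated_gb:.2f} GiB, peak: {peak_gb:.2f} GiB")

torch.cuda.reset_peak_memory_stats(0)  # start a fresh peak measurement
```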

### Verify Preprocessed Data

Decode latents to visualize the preprocessed videos:

```bash
uv run python scripts/decode_latents.py dataset/.precomputed/latents debug_output \
    --model-path /path/to/model.safetensors
```

To also decode audio latents, add the `--with-audio` flag:

```bash
uv run python scripts/decode_latents.py dataset/.precomputed/latents debug_output \
    --model-path /path/to/model.safetensors \
    --with-audio
```

Compare decoded videos and audio with the originals to ensure quality.

---

## 💡 Best Practices

### Before Training

- [ ] Test preprocessing with a small subset first
- [ ] Verify all video files are accessible
- [ ] Check available GPU memory
- [ ] Review configuration against hardware capabilities
- [ ] Ensure model and text encoder paths are correct

### During Training

- [ ] Monitor GPU memory usage
- [ ] Check loss convergence regularly
- [ ] Review validation samples periodically
- [ ] Save checkpoints frequently

### After Training

- [ ] Test trained model with diverse prompts
- [ ] Document training parameters and results
- [ ] Archive training data and configs

## 🆘 Getting Help

If you're still experiencing issues:

1. **Check logs:** Review console output for error details
2. **Search issues:** Look through GitHub issues for similar problems
3. **Provide details:** When reporting issues, include:
   - Hardware specifications (GPU model, VRAM)
   - Configuration file used
   - Complete error message
   - Steps to reproduce the issue

---

## 🤝 Join the Community

Have questions, want to share your results, or need real-time help?
Join our [community Discord server](https://discord.gg/ltxplatform)
to connect with other users and the development team!

- Get troubleshooting help
- Share your training results and workflows
- Stay up to date with announcements and updates

We look forward to seeing you there!
packages/ltx-trainer/docs/utility-scripts.md
ADDED
@@ -0,0 +1,274 @@
# Utility Scripts Reference

This guide covers the various utility scripts available for preprocessing, conversion, and debugging tasks.

## 🎬 Dataset Processing Scripts

### Video Scene Splitting

The `scripts/split_scenes.py` script automatically splits long videos into shorter, coherent scenes.

```bash
# Basic scene splitting
uv run python scripts/split_scenes.py input.mp4 output_dir/ --filter-shorter-than 5s
```

**Key features:**

- **Automatic scene detection**: Uses PySceneDetect for intelligent splitting
- **Multiple algorithms**: Content-based, adaptive, threshold, and histogram detection
- **Filtering options**: Remove scenes shorter than a specified duration
- **Customizable parameters**: Thresholds, window sizes, and detection modes

**Common options:**

```bash
# See all available options
uv run python scripts/split_scenes.py --help

# Use adaptive detection with custom threshold
uv run python scripts/split_scenes.py video.mp4 scenes/ --detector adaptive --threshold 30.0

# Limit to maximum number of scenes
uv run python scripts/split_scenes.py video.mp4 scenes/ --max-scenes 50
```

### Automatic Video Captioning

The `scripts/caption_videos.py` script generates captions for videos (with audio) using multimodal models.

```bash
# Generate captions for all videos in a directory (uses Qwen2.5-Omni by default)
uv run python scripts/caption_videos.py videos_dir/ --output dataset.json

# Use 8-bit quantization to reduce VRAM usage
uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --use-8bit

# Use Gemini Flash API instead (requires API key)
uv run python scripts/caption_videos.py videos_dir/ --output dataset.json \
    --captioner-type gemini_flash --api-key YOUR_API_KEY

# Caption without audio processing (video-only)
uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --no-audio

# Force re-caption all files
uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --override
```

**Key features:**

- **Audio-visual captioning**: Processes both video and audio content, including speech transcription
- **Multiple backends**:
  - `qwen_omni` (default): Local Qwen2.5-Omni model - processes video + audio locally
  - `gemini_flash`: Google Gemini Flash API - cloud-based, requires API key
- **Structured output**: Captions include visual description, speech transcription, sounds, and on-screen text
- **Memory optimization**: 8-bit quantization option for limited VRAM
- **Incremental processing**: Skips already-captioned files by default
- **Multiple output formats**: JSON, JSONL, CSV, or TXT

**Caption format:**

The captioner produces structured captions with four sections:

- `[VISUAL]`: Detailed description of visual content
- `[SPEECH]`: Word-for-word transcription of spoken content
- `[SOUNDS]`: Description of music, ambient sounds, sound effects
- `[TEXT]`: Any on-screen text visible in the video
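
If you post-process captions yourself (e.g., to drop the `[SOUNDS]` section for video-only training), the tags can be
split with a small helper like this sketch, which assumes the tags appear literally in the caption text:

```python
import re

def split_caption_sections(caption: str) -> dict[str, str]:
    """Split a structured caption into its [VISUAL]/[SPEECH]/[SOUNDS]/[TEXT] parts."""
    sections: dict[str, str] = {}
    # Capture each [TAG] and the text that follows it, up to the next tag.
    for tag, body in re.findall(
        r"\[(VISUAL|SPEECH|SOUNDS|TEXT)\](.*?)(?=\[(?:VISUAL|SPEECH|SOUNDS|TEXT)\]|$)",
        caption,
        flags=re.DOTALL,
    ):
        sections[tag] = body.strip()
    return sections

caption = "[VISUAL] A cat on a sofa. [SOUNDS] Soft purring."
print(split_caption_sections(caption))  # {'VISUAL': 'A cat on a sofa.', 'SOUNDS': 'Soft purring.'}
```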

**Environment variables (for Gemini Flash):**

Set one of these to use Gemini Flash without passing `--api-key`:

- `GOOGLE_API_KEY`
- `GEMINI_API_KEY`

### Dataset Preprocessing

The `scripts/process_dataset.py` script processes videos and caches latents for training.

```bash
# Basic preprocessing
uv run python scripts/process_dataset.py dataset.json \
    --resolution-buckets "960x544x49" \
    --model-path /path/to/ltx-2-model.safetensors \
    --text-encoder-path /path/to/gemma-model

# With audio processing
uv run python scripts/process_dataset.py dataset.json \
    --resolution-buckets "960x544x49" \
    --model-path /path/to/ltx-2-model.safetensors \
    --text-encoder-path /path/to/gemma-model \
    --with-audio

# With video decoding for verification
uv run python scripts/process_dataset.py dataset.json \
    --resolution-buckets "960x544x49" \
    --model-path /path/to/ltx-2-model.safetensors \
    --text-encoder-path /path/to/gemma-model \
    --decode
```

Multiple resolution buckets can be specified, separated by `;`:

```bash
uv run python scripts/process_dataset.py dataset.json \
    --resolution-buckets "960x544x49;512x512x81" \
    --model-path /path/to/ltx-2-model.safetensors \
    --text-encoder-path /path/to/gemma-model
```

> [!NOTE]
> When training with multiple resolution buckets, set `optimization.batch_size: 1`.

For detailed usage, see the [Dataset Preparation Guide](dataset-preparation.md).

### Reference Video Generation

The `scripts/compute_reference.py` script provides a template for creating reference videos needed for IC-LoRA training.
The default implementation generates Canny edge reference videos.

```bash
# Generate Canny edge reference videos
uv run python scripts/compute_reference.py videos_dir/ --output dataset.json
```

**Key features:**

- **Canny edge detection**: Creates edge-based reference videos
- **In-place editing**: Updates existing dataset JSON files
- **Customizable**: Modify the `compute_reference()` function for different conditions (depth, pose, etc.)

> [!TIP]
> You can edit this script to generate other types of reference videos for IC-LoRA training,
> such as depth maps, segmentation masks, or any custom video transformation.

## 🔍 Debugging and Verification Scripts

### Latents Decoding

The `scripts/decode_latents.py` script decodes precomputed video latents back into video files for visual inspection.

```bash
# Basic usage
uv run python scripts/decode_latents.py /path/to/latents/dir \
    --output-dir /path/to/output \
    --model-path /path/to/ltx-2-model.safetensors

# With VAE tiling for large videos
uv run python scripts/decode_latents.py /path/to/latents/dir \
    --output-dir /path/to/output \
    --model-path /path/to/ltx-2-model.safetensors \
    --vae-tiling

# Decode both video and audio latents
uv run python scripts/decode_latents.py /path/to/latents/dir \
    --output-dir /path/to/output \
    --model-path /path/to/ltx-2-model.safetensors \
    --with-audio
```

**The script will:**

1. **Load the VAE model** from the specified path
2. **Process all `.pt` latent files** in the input directory
3. **Decode each latent** back into a video using the VAE
4. **Save the resulting videos** as MP4 files in the output directory

**When to use:**

- **Verify preprocessing quality**: Check that your videos were encoded correctly
- **Debug training data**: Visualize what the model actually sees during training
- **Quality assessment**: Ensure latent encoding preserves important visual details
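
For an even quicker sanity check than a full decode, you can inspect a cached latent file directly. A sketch, assuming
a `.pt` file under the precomputed latents directory (the file name here is hypothetical, and the exact payload layout
may differ between trainer versions):

```python
import torch

# Latents are stored as .pt files; load on CPU to avoid touching the GPU.
payload = torch.load("dataset/.precomputed/latents/sample_0000.pt", map_location="cpu")

# Depending on the trainer version this may be a raw tensor or a dict of tensors.
if isinstance(payload, dict):
    for name, value in payload.items():
        if torch.is_tensor(value):
            print(f"{name}: shape={tuple(value.shape)}, dtype={value.dtype}")
else:
    print(f"latent: shape={tuple(payload.shape)}, dtype={payload.dtype}")
```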

### Inference Script

The `scripts/inference.py` script runs inference with a trained model.

> [!TIP]
> For production inference, consider using the [`ltx-pipelines`](../../ltx-pipelines/) package, which provides
> optimized, feature-rich pipelines for various use cases:
>
> - **Text/Image-to-Video**: `TI2VidOneStagePipeline`, `TI2VidTwoStagesPipeline`
> - **Distilled (fast) inference**: `DistilledPipeline`
> - **IC-LoRA video-to-video**: `ICLoraPipeline`
> - **Keyframe interpolation**: `KeyframeInterpolationPipeline`
>
> All pipelines support loading custom LoRAs trained with this trainer.

```bash
# Text-to-video inference (with audio by default)
# By default, uses CFG scale 4.0 and STG scale 1.0 with block 29
uv run python scripts/inference.py \
    --checkpoint /path/to/model.safetensors \
    --text-encoder-path /path/to/gemma \
    --prompt "A cat playing with a ball" \
    --output output.mp4

# Video-only (skip audio generation)
uv run python scripts/inference.py \
    --checkpoint /path/to/model.safetensors \
    --text-encoder-path /path/to/gemma \
    --prompt "A cat playing with a ball" \
    --skip-audio \
    --output output.mp4

# Image-to-video with conditioning image
uv run python scripts/inference.py \
    --checkpoint /path/to/model.safetensors \
    --text-encoder-path /path/to/gemma \
    --prompt "A cat walking" \
    --condition-image first_frame.png \
    --output output.mp4

# Custom guidance settings
uv run python scripts/inference.py \
    --checkpoint /path/to/model.safetensors \
    --text-encoder-path /path/to/gemma \
    --prompt "A cat playing with a ball" \
    --guidance-scale 4.0 \
    --stg-scale 1.0 \
    --stg-blocks 29 \
    --output output.mp4

# Disable STG (CFG only)
uv run python scripts/inference.py \
    --checkpoint /path/to/model.safetensors \
    --text-encoder-path /path/to/gemma \
    --prompt "A cat playing with a ball" \
    --stg-scale 0.0 \
    --output output.mp4
```

**Guidance parameters:**

| Parameter          | Default | Description                                                |
|--------------------|---------|------------------------------------------------------------|
| `--guidance-scale` | 4.0     | CFG (Classifier-Free Guidance) scale                       |
| `--stg-scale`      | 1.0     | STG (Spatio-Temporal Guidance) scale; 0.0 disables STG     |
| `--stg-blocks`     | 29      | Transformer block(s) to perturb for STG                    |
| `--stg-mode`       | stg_av  | `stg_av` perturbs both audio and video, `stg_v` video only |

## 🚀 Training Scripts

### Basic and Distributed Training

Use `scripts/train.py` for both single-GPU and multi-GPU runs:

```bash
# Single-GPU training
uv run python scripts/train.py configs/ltx2_av_lora.yaml

# Multi-GPU (uses your accelerate config)
uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml

# Override number of processes
uv run accelerate launch --num_processes 4 scripts/train.py configs/ltx2_av_lora.yaml
```

For detailed usage, see the [Training Guide](training-guide.md).

## 💡 Tips for Using Utility Scripts

- **Start with `--help`**: Always check the available options for each script
- **Test on small datasets**: Verify workflows with a few files before processing large datasets
- **Use decode verification**: Always decode a few samples to verify preprocessing quality
- **Monitor VRAM usage**: Use `--use-8bit` or quantization flags when running into memory issues
- **Keep backups**: Make copies of important dataset files before running conversion scripts
packages/ltx-trainer/scripts/caption_videos.py
ADDED
@@ -0,0 +1,486 @@
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Auto-caption videos with audio using multimodal models.
|
| 5 |
+
This script provides a command-line interface for generating captions for videos
|
| 6 |
+
(including audio) using multimodal models. It supports:
|
| 7 |
+
- Qwen2.5-Omni: Local model for audio-visual captioning (default)
|
| 8 |
+
- Gemini Flash: Cloud-based API for audio-visual captioning
|
| 9 |
+
The paths to videos in the generated dataset/captions file will be RELATIVE to the
|
| 10 |
+
directory where the output file is stored. This makes the dataset more portable and
|
| 11 |
+
easier to use in different environments.
|
| 12 |
+
Basic usage:
|
| 13 |
+
# Caption a single video (includes audio by default)
|
| 14 |
+
caption_videos.py video.mp4 --output captions.json
|
| 15 |
+
# Caption all videos in a directory
|
| 16 |
+
caption_videos.py videos_dir/ --output captions.csv
|
| 17 |
+
# Caption with custom instruction
|
| 18 |
+
caption_videos.py video.mp4 --instruction "Describe what happens in this video in detail."
|
| 19 |
+
Advanced usage:
|
| 20 |
+
# Use Gemini Flash API (requires GEMINI_API_KEY or GOOGLE_API_KEY env var)
|
| 21 |
+
caption_videos.py videos_dir/ --captioner-type gemini_flash
|
| 22 |
+
# Disable audio processing (video-only captions)
|
| 23 |
+
caption_videos.py videos_dir/ --no-audio
|
| 24 |
+
# Process videos with specific extensions and save as JSON
|
| 25 |
+
caption_videos.py videos_dir/ --extensions mp4,mov,avi --output captions.json
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import csv
|
| 29 |
+
import json
|
| 30 |
+
from enum import Enum
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import typer
|
| 35 |
+
from rich.console import Console
|
| 36 |
+
from rich.progress import (
|
| 37 |
+
BarColumn,
|
| 38 |
+
MofNCompleteColumn,
|
| 39 |
+
Progress,
|
| 40 |
+
SpinnerColumn,
|
| 41 |
+
TextColumn,
|
| 42 |
+
TimeElapsedColumn,
|
| 43 |
+
TimeRemainingColumn,
|
| 44 |
+
)
|
| 45 |
+
from transformers.utils.logging import disable_progress_bar
|
| 46 |
+
|
| 47 |
+
from ltx_trainer.captioning import CaptionerType, MediaCaptioningModel, create_captioner
|
| 48 |
+
|
| 49 |
+
VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"]
|
| 50 |
+
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png"]
|
| 51 |
+
MEDIA_EXTENSIONS = VIDEO_EXTENSIONS + IMAGE_EXTENSIONS
|
| 52 |
+
SAVE_INTERVAL = 5
|
| 53 |
+
|
| 54 |
+
console = Console()
|
| 55 |
+
app = typer.Typer(
|
| 56 |
+
pretty_exceptions_enable=False,
|
| 57 |
+
no_args_is_help=True,
|
| 58 |
+
help="Auto-caption videos with audio using multimodal models.",
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
disable_progress_bar()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class OutputFormat(str, Enum):
|
| 65 |
+
"""Available output formats for captions."""
|
| 66 |
+
|
| 67 |
+
TXT = "txt" # Separate files for captions and video paths, one caption / video path per line
|
| 68 |
+
CSV = "csv" # CSV file with video path and caption columns
|
| 69 |
+
JSON = "json" # JSON file with video paths as keys and captions as values
|
| 70 |
+
JSONL = "jsonl" # JSON Lines file with one JSON object per line
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def caption_media(
|
| 74 |
+
input_path: Path,
|
| 75 |
+
output_path: Path,
|
| 76 |
+
captioner: MediaCaptioningModel,
|
| 77 |
+
extensions: list[str],
|
| 78 |
+
recursive: bool,
|
| 79 |
+
fps: int,
|
| 80 |
+
include_audio: bool,
|
| 81 |
+
clean_caption: bool,
|
| 82 |
+
output_format: OutputFormat,
|
| 83 |
+
override: bool,
|
| 84 |
+
) -> None:
|
| 85 |
+
"""Caption videos and images using the provided captioning model.
|
| 86 |
+
Args:
|
| 87 |
+
input_path: Path to input video file or directory
|
| 88 |
+
output_path: Path to output caption file
|
| 89 |
+
captioner: Media captioning model
|
| 90 |
+
extensions: List of media file extensions to include
|
| 91 |
+
recursive: Whether to search subdirectories recursively
|
| 92 |
+
fps: Frames per second to sample from videos (ignored for images)
|
| 93 |
+
include_audio: Whether to include audio in captioning
|
| 94 |
+
clean_caption: Whether to clean up captions
|
| 95 |
+
output_format: Format to save the captions in
|
| 96 |
+
override: Whether to override existing captions
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
# Get list of media files to process
|
| 100 |
+
media_files = _get_media_files(input_path, extensions, recursive)
|
| 101 |
+
|
| 102 |
+
if not media_files:
|
| 103 |
+
console.print("[bold yellow]No media files found to process.[/]")
|
| 104 |
+
return
|
| 105 |
+
|
| 106 |
+
console.print(f"Found [bold]{len(media_files)}[/] media files to process.")
|
| 107 |
+
|
| 108 |
+
# Load existing captions and determine which files need processing
|
| 109 |
+
base_dir = output_path.parent.resolve()
|
| 110 |
+
existing_captions = _load_existing_captions(output_path, output_format)
|
| 111 |
+
existing_abs_paths = {str((base_dir / p).resolve()) for p in existing_captions}
|
| 112 |
+
|
| 113 |
+
if override:
|
| 114 |
+
media_to_process = media_files
|
| 115 |
+
else:
|
| 116 |
+
media_to_process = [f for f in media_files if str(f.resolve()) not in existing_abs_paths]
|
| 117 |
+
if skipped := len(media_files) - len(media_to_process):
|
| 118 |
+
console.print(f"[bold yellow]Skipping {skipped} media that already have captions.[/]")
|
| 119 |
+
|
| 120 |
+
if not media_to_process:
|
| 121 |
+
console.print("[bold yellow]All media already have captions. Use --override to recaption.[/]")
|
| 122 |
+
return
|
| 123 |
+
|
| 124 |
+
# Process media files
|
| 125 |
+
captions = existing_captions.copy()
|
| 126 |
+
successfully_captioned = 0
|
| 127 |
+
progress = Progress(
|
| 128 |
+
SpinnerColumn(),
|
| 129 |
+
TextColumn("{task.description}"),
|
| 130 |
+
BarColumn(bar_width=40),
|
| 131 |
+
MofNCompleteColumn(),
|
| 132 |
+
TimeElapsedColumn(),
|
| 133 |
+
TextColumn("•"),
|
| 134 |
+
TimeRemainingColumn(),
|
| 135 |
+
console=console,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
with progress:
|
| 139 |
+
task = progress.add_task("Captioning", total=len(media_to_process))
|
| 140 |
+
|
| 141 |
+
for i, media_file in enumerate(media_to_process):
|
| 142 |
+
progress.update(task, description=f"Captioning [bold blue]{media_file.name}[/]")
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
# Generate caption for the media
|
| 146 |
+
caption = captioner.caption(
|
| 147 |
+
path=media_file,
|
| 148 |
+
fps=fps,
|
| 149 |
+
include_audio=include_audio,
|
| 150 |
+
clean_caption=clean_caption,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Convert absolute path to relative path (relative to the output file's directory)
|
| 154 |
+
rel_path = str(media_file.resolve().relative_to(base_dir))
|
| 155 |
+
# Store the caption with the relative path as key
|
| 156 |
+
captions[rel_path] = caption
|
| 157 |
+
successfully_captioned += 1
|
| 158 |
+
except Exception as e:
|
| 159 |
+
console.print(f"[bold red]Error captioning {media_file}: {e}[/]")
|
| 160 |
+
|
| 161 |
+
if i % SAVE_INTERVAL == 0:
|
| 162 |
+
_save_captions(captions, output_path, output_format)
|
| 163 |
+
|
| 164 |
+
# Advance progress bar
|
| 165 |
+
progress.advance(task)
|
| 166 |
+
|
| 167 |
+
# Save captions to file
|
| 168 |
+
_save_captions(captions, output_path, output_format)
|
| 169 |
+
|
| 170 |
+
# Print summary
|
| 171 |
+
console.print(
|
| 172 |
+
f"[bold green]✓[/] Captioned [bold]{successfully_captioned}/{len(media_to_process)}[/] media successfully.",
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _get_media_files(
|
| 177 |
+
input_path: Path,
|
| 178 |
+
extensions: list[str] = MEDIA_EXTENSIONS,
|
| 179 |
+
recursive: bool = False,
|
| 180 |
+
) -> list[Path]:
|
| 181 |
+
"""Get all media files from the input path."""
|
| 182 |
+
input_path = Path(input_path)
|
| 183 |
+
# Normalize extensions to lowercase without dots
|
| 184 |
+
extensions_set = {ext.lower().lstrip(".") for ext in extensions}
|
| 185 |
+
|
| 186 |
+
if input_path.is_file():
|
| 187 |
+
# If input is a file, check if it has a valid extension
|
| 188 |
+
if input_path.suffix.lstrip(".").lower() in extensions_set:
|
| 189 |
+
return [input_path]
|
| 190 |
+
else:
|
| 191 |
+
typer.echo(f"Warning: {input_path} is not a recognized media file. Skipping.")
|
| 192 |
+
return []
|
| 193 |
+
elif input_path.is_dir():
|
| 194 |
+
# Find all files and filter by extension case-insensitively
|
| 195 |
+
glob_pattern = "**/*" if recursive else "*"
|
| 196 |
+
media_files = [
|
| 197 |
+
f for f in input_path.glob(glob_pattern) if f.is_file() and f.suffix.lstrip(".").lower() in extensions_set
|
| 198 |
+
]
|
| 199 |
+
return sorted(media_files)
|
| 200 |
+
else:
|
| 201 |
+
typer.echo(f"Error: {input_path} does not exist.")
|
| 202 |
+
raise typer.Exit(code=1)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _save_captions(
|
| 206 |
+
captions: dict[str, str],
|
| 207 |
+
output_path: Path,
|
| 208 |
+
format_type: OutputFormat,
|
| 209 |
+
) -> None:
|
| 210 |
+
"""Save captions to a file in the specified format.
|
| 211 |
+
Args:
|
| 212 |
+
captions: Dictionary mapping media paths to captions
|
| 213 |
+
output_path: Path to save the output file
|
| 214 |
+
format_type: Format to save the captions in
|
| 215 |
+
"""
|
| 216 |
+
# Create parent directories if they don't exist
|
| 217 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 218 |
+
|
| 219 |
+
console.print("[bold blue]Saving captions...[/]")
|
| 220 |
+
|
| 221 |
+
match format_type:
|
| 222 |
+
case OutputFormat.TXT:
|
| 223 |
+
# Create two separate files for captions and media paths
|
| 224 |
+
captions_file = output_path.with_stem(f"{output_path.stem}_captions")
|
| 225 |
+
paths_file = output_path.with_stem(f"{output_path.stem}_paths")
|
| 226 |
+
|
| 227 |
+
with captions_file.open("w", encoding="utf-8") as f:
|
| 228 |
+
for caption in captions.values():
|
| 229 |
+
f.write(f"{caption}\n")
|
| 230 |
+
|
| 231 |
+
with paths_file.open("w", encoding="utf-8") as f:
|
| 232 |
+
for media_path in captions:
|
| 233 |
+
f.write(f"{media_path}\n")
|
| 234 |
+
|
| 235 |
+
console.print(f"[bold green]✓[/] Captions saved to [cyan]{captions_file}[/]")
|
| 236 |
+
console.print(f"[bold green]✓[/] Media paths saved to [cyan]{paths_file}[/]")
|
| 237 |
+
|
| 238 |
+
case OutputFormat.CSV:
|
| 239 |
+
with output_path.open("w", encoding="utf-8", newline="") as f:
|
| 240 |
+
writer = csv.writer(f)
|
| 241 |
+
writer.writerow(["caption", "media_path"])
|
| 242 |
+
for media_path, caption in captions.items():
|
| 243 |
+
writer.writerow([caption, media_path])
|
| 244 |
+
|
| 245 |
+
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
|
| 246 |
+
|
| 247 |
+
case OutputFormat.JSON:
|
| 248 |
+
# Format as list of dictionaries with caption and media_path keys
|
| 249 |
+
json_data = [{"caption": caption, "media_path": media_path} for media_path, caption in captions.items()]
|
| 250 |
+
|
| 251 |
+
with output_path.open("w", encoding="utf-8") as f:
|
| 252 |
+
json.dump(json_data, f, indent=2, ensure_ascii=False)
|
| 253 |
+
|
| 254 |
+
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
|
| 255 |
+
|
| 256 |
+
case OutputFormat.JSONL:
|
| 257 |
+
with output_path.open("w", encoding="utf-8") as f:
|
| 258 |
+
for media_path, caption in captions.items():
|
| 259 |
+
f.write(json.dumps({"caption": caption, "media_path": media_path}, ensure_ascii=False) + "\n")
|
| 260 |
+
|
| 261 |
+
console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
|
| 262 |
+
|
| 263 |
+
case _:
|
| 264 |
+
raise ValueError(f"Unsupported output format: {format_type}")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _load_existing_captions( # noqa: PLR0912
|
| 268 |
+
output_path: Path,
|
| 269 |
+
format_type: OutputFormat,
|
| 270 |
+
) -> dict[str, str]:
|
| 271 |
+
"""Load existing captions from a file.
|
| 272 |
+
Args:
|
| 273 |
+
output_path: Path to the captions file
|
| 274 |
+
format_type: Format of the captions file
|
| 275 |
+
Returns:
|
| 276 |
+
Dictionary mapping media paths to captions, or empty dict if file doesn't exist
|
| 277 |
+
"""
|
| 278 |
+
if not output_path.exists():
|
| 279 |
+
return {}
|
| 280 |
+
|
| 281 |
+
console.print(f"[bold blue]Loading existing captions from [cyan]{output_path}[/]...[/]")
|
| 282 |
+
|
| 283 |
+
existing_captions = {}
|
| 284 |
+
|
| 285 |
+
try:
|
| 286 |
+
match format_type:
|
| 287 |
+
case OutputFormat.TXT:
|
| 288 |
+
# For TXT format, we have two separate files
|
| 289 |
+
captions_file = output_path.with_stem(f"{output_path.stem}_captions")
|
| 290 |
+
paths_file = output_path.with_stem(f"{output_path.stem}_paths")
|
| 291 |
+
|
| 292 |
+
if captions_file.exists() and paths_file.exists():
|
| 293 |
+
captions = captions_file.read_text(encoding="utf-8").splitlines()
|
| 294 |
+
paths = paths_file.read_text(encoding="utf-8").splitlines()
|
| 295 |
+
|
| 296 |
+
if len(captions) == len(paths):
|
| 297 |
+
existing_captions = dict(zip(paths, captions, strict=False))
|
| 298 |
+
|
| 299 |
+
case OutputFormat.CSV:
|
| 300 |
+
with output_path.open("r", encoding="utf-8", newline="") as f:
|
| 301 |
+
reader = csv.reader(f)
|
| 302 |
+
# Skip header
|
| 303 |
+
next(reader, None)
|
| 304 |
+
for row in reader:
|
| 305 |
+
if len(row) >= 2:
|
| 306 |
+
caption, media_path = row[0], row[1]
|
| 307 |
+
existing_captions[media_path] = caption
|
| 308 |
+
|
| 309 |
+
case OutputFormat.JSON:
|
| 310 |
+
with output_path.open("r", encoding="utf-8") as f:
|
| 311 |
+
json_data = json.load(f)
|
| 312 |
+
for item in json_data:
|
| 313 |
+
if "caption" in item and "media_path" in item:
|
| 314 |
+
existing_captions[item["media_path"]] = item["caption"]
|
| 315 |
+
|
| 316 |
+
case OutputFormat.JSONL:
|
| 317 |
+
with output_path.open("r", encoding="utf-8") as f:
|
| 318 |
+
for line in f:
|
| 319 |
+
item = json.loads(line)
|
| 320 |
+
if "caption" in item and "media_path" in item:
|
| 321 |
+
existing_captions[item["media_path"]] = item["caption"]
|
| 322 |
+
|
| 323 |
+
case _:
|
| 324 |
+
raise ValueError(f"Unsupported output format: {format_type}")
|
| 325 |
+
|
| 326 |
+
console.print(f"[bold green]✓[/] Loaded [bold]{len(existing_captions)}[/] existing captions")
|
| 327 |
+
return existing_captions
|
| 328 |
+
|
| 329 |
+
except Exception as e:
|
| 330 |
+
console.print(f"[bold yellow]Warning: Could not load existing captions: {e}[/]")
|
| 331 |
+
return {}
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@app.command()
|
| 335 |
+
def main( # noqa: PLR0913
|
| 336 |
+
input_path: Path = typer.Argument( # noqa: B008
|
| 337 |
+
...,
|
| 338 |
+
help="Path to input video/image file or directory containing media files",
|
| 339 |
+
exists=True,
|
| 340 |
+
),
|
| 341 |
+
output: Path | None = typer.Option( # noqa: B008
|
| 342 |
+
None,
|
| 343 |
+
"--output",
|
| 344 |
+
"-o",
|
| 345 |
+
help="Path to output file for captions. Format determined by file extension.",
|
| 346 |
+
),
|
| 347 |
+
captioner_type: CaptionerType = typer.Option( # noqa: B008
|
| 348 |
+
CaptionerType.QWEN_OMNI,
|
| 349 |
+
"--captioner-type",
|
| 350 |
+
"-c",
|
| 351 |
+
help="Type of captioner to use. Valid values: 'qwen_omni' (local), 'gemini_flash' (API)",
|
| 352 |
+
case_sensitive=False,
|
| 353 |
+
),
|
| 354 |
+
device: str | None = typer.Option(
|
| 355 |
+
None,
|
| 356 |
+
"--device",
|
| 357 |
+
"-d",
|
| 358 |
+
help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.",
|
| 359 |
+
),
|
| 360 |
+
use_8bit: bool = typer.Option(
|
| 361 |
+
False,
|
| 362 |
+
"--use-8bit",
|
| 363 |
+
help="Whether to use 8-bit precision for the captioning model (reduces memory usage)",
|
| 364 |
+
),
|
| 365 |
+
instruction: str | None = typer.Option(
|
| 366 |
+
None,
|
| 367 |
+
"--instruction",
|
| 368 |
+
"-i",
|
| 369 |
+
help="Custom instruction for the captioning model. If not provided, uses an appropriate default.",
|
| 370 |
+
),
|
| 371 |
+
extensions: str = typer.Option(
|
| 372 |
+
",".join(MEDIA_EXTENSIONS),
|
| 373 |
+
"--extensions",
|
| 374 |
+
"-e",
|
| 375 |
+
help="Comma-separated list of media file extensions to process",
|
| 376 |
+
),
|
| 377 |
+
recursive: bool = typer.Option(
|
| 378 |
+
False,
|
| 379 |
+
"--recursive",
|
| 380 |
+
"-r",
|
| 381 |
+
help="Search for media files in subdirectories recursively",
|
| 382 |
+
),
|
| 383 |
+
fps: int = typer.Option(
|
| 384 |
+
3,
|
| 385 |
+
"--fps",
|
| 386 |
+
"-f",
|
| 387 |
+
help="Frames per second to sample from videos (ignored for images)",
|
| 388 |
+
),
|
| 389 |
+
include_audio: bool = typer.Option(
|
| 390 |
+
True,
|
| 391 |
+
"--audio/--no-audio",
|
| 392 |
+
help="Whether to include audio in captioning (for videos with audio tracks)",
|
| 393 |
+
),
|
| 394 |
+
clean_caption: bool = typer.Option(
|
| 395 |
+
True,
|
| 396 |
+
"--clean-caption/--raw-caption",
|
| 397 |
+
help="Whether to clean up captions by removing common VLM patterns",
|
| 398 |
+
),
|
| 399 |
+
override: bool = typer.Option(
|
| 400 |
+
False,
|
| 401 |
+
"--override",
|
| 402 |
+
help="Whether to override existing captions for media",
|
| 403 |
+
),
|
| 404 |
+
api_key: str | None = typer.Option(
|
| 405 |
+
None,
|
| 406 |
+
"--api-key",
|
| 407 |
+
envvar=["GOOGLE_API_KEY", "GEMINI_API_KEY"],
|
| 408 |
+
help="API key for Gemini Flash (can also use GOOGLE_API_KEY or GEMINI_API_KEY env var)",
|
| 409 |
+
),
|
| 410 |
+
) -> None:
|
| 411 |
+
"""Auto-caption videos with audio using multimodal models.
|
| 412 |
+
This script supports audio-visual captioning using:
|
| 413 |
+
- Qwen2.5-Omni: Local model (default) - processes both video and audio
|
| 414 |
+
- Gemini Flash: Cloud API - requires GOOGLE_API_KEY environment variable
|
| 415 |
+
The paths in the output file will be relative to the output file's directory.
|
| 416 |
+
Examples:
|
| 417 |
+
# Caption videos with audio using Qwen2.5-Omni (default)
|
| 418 |
+
caption_videos.py videos_dir/ -o captions.json
|
| 419 |
+
# Caption using Gemini Flash API
|
| 420 |
+
caption_videos.py videos_dir/ -o captions.json -c gemini_flash
|
| 421 |
+
# Caption without audio (video-only)
|
| 422 |
+
caption_videos.py videos_dir/ -o captions.json --no-audio
|
| 423 |
+
# Caption with custom instruction
|
| 424 |
+
caption_videos.py video.mp4 -o captions.json -i "Describe this video in detail"
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
# Determine device for local models
|
| 428 |
+
device_str = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 429 |
+
|
| 430 |
+
# Parse extensions
|
| 431 |
+
ext_list = [ext.strip() for ext in extensions.split(",")]
|
| 432 |
+
|
| 433 |
+
# Determine output path and format
|
| 434 |
+
if output is None:
|
| 435 |
+
output_format = OutputFormat.JSON
|
| 436 |
+
if input_path.is_file(): # noqa: SIM108
|
| 437 |
+
# Default to a JSON file with the same name as the input media
|
| 438 |
+
output = input_path.with_suffix(".dataset.json")
|
| 439 |
+
else:
|
| 440 |
+
# Default to a JSON file in the input directory
|
| 441 |
+
output = input_path / "dataset.json"
|
| 442 |
+
else:
|
| 443 |
+
# Determine format from file extension
|
| 444 |
+
output_format = OutputFormat(Path(output).suffix.lstrip(".").lower())
|
| 445 |
+
|
| 446 |
+
# Ensure output path is absolute
|
| 447 |
+
output = Path(output).resolve()
|
| 448 |
+
console.print(f"Output will be saved to [bold blue]{output}[/]")
|
| 449 |
+
|
| 450 |
+
# Initialize captioning model
|
| 451 |
+
with console.status("Loading captioning model...", spinner="dots"):
|
| 452 |
+
if captioner_type == CaptionerType.QWEN_OMNI:
|
| 453 |
+
captioner = create_captioner(
|
| 454 |
+
captioner_type=captioner_type,
|
| 455 |
+
device=device_str,
|
| 456 |
+
use_8bit=use_8bit,
|
| 457 |
+
instruction=instruction,
|
| 458 |
+
)
|
| 459 |
+
elif captioner_type == CaptionerType.GEMINI_FLASH:
|
| 460 |
+
captioner = create_captioner(
|
| 461 |
+
captioner_type=captioner_type,
|
| 462 |
+
api_key=api_key,
|
| 463 |
+
instruction=instruction,
|
| 464 |
+
)
|
| 465 |
+
else:
|
| 466 |
+
raise ValueError(f"Unsupported captioner type: {captioner_type}")
|
| 467 |
+
|
| 468 |
+
console.print(f"[bold green]✓[/] {captioner_type.value} captioning model loaded successfully")
|
| 469 |
+
|
| 470 |
+
# Caption media files
|
| 471 |
+
caption_media(
|
| 472 |
+
input_path=input_path,
|
| 473 |
+
output_path=output,
|
| 474 |
+
captioner=captioner,
|
| 475 |
+
extensions=ext_list,
|
| 476 |
+
recursive=recursive,
|
| 477 |
+
fps=fps,
|
| 478 |
+
include_audio=include_audio,
|
| 479 |
+
clean_caption=clean_caption,
|
| 480 |
+
output_format=output_format,
|
| 481 |
+
override=override,
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
if __name__ == "__main__":
|
| 486 |
+
app()
|
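For context, the dataset file written here is what the downstream scripts consume: compute_reference.py and process_captions.py both read it as a JSON list of entries carrying a media path and a caption. A minimal sketch of reading it back (illustrative file path; field names follow the defaults used by those scripts):

# Illustrative sketch (not from the source): reading the dataset JSON
# produced by caption_videos.py, as consumed by compute_reference.py
# ("media_path") and process_captions.py (default "caption" column).
import json
from pathlib import Path

entries = json.loads(Path("videos_dir/dataset.json").read_text(encoding="utf-8"))
for entry in entries:
    print(entry["media_path"], "->", entry["caption"][:60])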
packages/ltx-trainer/scripts/compute_reference.py
ADDED
@@ -0,0 +1,288 @@
"""
Compute reference videos for IC-LoRA training.

This script provides a command-line interface for generating reference videos to be used for IC-LoRA training.
Note that it reads and writes to the same file (the output of caption_videos.py),
where it adds the "reference_path" field to the JSON.

Basic usage:
    # Compute reference videos for all videos in a directory
    compute_reference.py videos_dir/ --output videos_dir/captions.json
"""

# Standard library imports
import json
from pathlib import Path

# Third-party imports
import cv2
import torch
import torchvision.transforms.functional as TF  # noqa: N812
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

# Local imports
from ltx_trainer.video_utils import read_video, save_video

# Initialize console and disable progress bars
console = Console()
disable_progress_bar()


def compute_reference(
    images: torch.Tensor,
) -> torch.Tensor:
    """Compute Canny edge detection on a batch of images.

    Args:
        images: Batch of images tensor of shape [B, C, H, W]

    Returns:
        3-channel binary edge masks tensor of shape [B, 3, H, W]
    """
    # Convert to grayscale if needed
    if images.shape[1] == 3:
        images = TF.rgb_to_grayscale(images)

    # Ensure images are in [0, 1] range
    if images.max() > 1.0:
        images = images / 255.0

    # Compute Canny edges
    edge_masks = []
    for image in images:
        # Convert to numpy for OpenCV
        image_np = (image.squeeze().cpu().numpy() * 255).astype("uint8")

        # Apply Canny edge detection
        edges = cv2.Canny(
            image_np,
            threshold1=100,
            threshold2=200,
        )

        # Convert back to tensor
        edge_mask = torch.from_numpy(edges).float()
        edge_masks.append(edge_mask)

    edges = torch.stack(edge_masks)
    edges = torch.stack([edges] * 3, dim=1)  # Convert to 3-channel
    return edges


def _get_meta_data(
    output_path: Path,
) -> list[dict[str, str]]:
    """Read the existing dataset entries without loading the actual media files.

    Args:
        output_path: Path to the dataset JSON file

    Returns:
        List of dataset entries, each a dict with at least a "media_path" key
    """
    if not output_path.exists():
        return []

    console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]")

    try:
        with output_path.open("r", encoding="utf-8") as f:
            json_data = json.load(f)
        return json_data
    except Exception as e:
        console.print(f"[bold yellow]Warning: Could not read meta data: {e}[/]")
        return []


def _save_dataset_json(
    reference_paths: dict[str, str],
    output_path: Path,
) -> None:
    """Save dataset json with reference video paths.

    Args:
        reference_paths: Dictionary mapping media paths to reference video paths
        output_path: Path to save the output file
    """
    with output_path.open("r", encoding="utf-8") as f:
        json_data = json.load(f)

    new_json_data = json_data.copy()
    for i, item in enumerate(json_data):
        media_path = item["media_path"]
        reference_path = reference_paths.get(media_path)
        if reference_path is None:
            # Skip entries whose reference video could not be computed
            continue
        new_json_data[i]["reference_path"] = reference_path

    with output_path.open("w", encoding="utf-8") as f:
        json.dump(new_json_data, f, indent=2, ensure_ascii=False)

    console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]")
    console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:")
    console.print("    reference_column='[cyan]reference_path[/]'")
    console.print("    video_column='[cyan]media_path[/]'")


def process_media(
    input_path: Path,
    output_path: Path,
    override: bool,
    batch_size: int = 100,
) -> None:
    """Process videos and images to compute condition (reference) videos.

    Args:
        input_path: Path to input video/image file or directory
        output_path: Path to the dataset JSON file (read and updated in place)
        override: Whether to override existing reference video files
        batch_size: Number of frames to process per Canny batch
    """
    if not output_path.exists():
        raise FileNotFoundError(
            f"Output file does not exist: {output_path}. This is also the input file for the dataset."
        )

    # Read the existing dataset entries
    meta_data = _get_meta_data(output_path)

    base_dir = input_path.resolve()
    console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths")

    # Collect media files to process
    media_to_process = []
    skipped_media = []

    def media_path_to_reference_path(media_file: Path) -> Path:
        return media_file.parent / (media_file.stem + "_reference" + media_file.suffix)

    media_files = [base_dir / Path(sample["media_path"]) for sample in meta_data]
    for media_file in media_files:
        media_to_process.append(media_file)

    console.print(f"Processing [bold]{len(media_to_process)}[/] media.")

    # Initialize progress tracking
    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )

    # Map each media path to its (relative) reference video path
    media_paths = [item["media_path"] for item in meta_data]
    reference_paths = {rel_path: str(media_path_to_reference_path(Path(rel_path))) for rel_path in media_paths}

    with progress:
        task = progress.add_task("Computing condition on videos", total=len(media_to_process))

        for media_file in media_to_process:
            progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]")

            rel_path = str(media_file.resolve().relative_to(base_dir))
            reference_path = media_path_to_reference_path(media_file)
            reference_paths[rel_path] = str(reference_path.relative_to(base_dir))

            if not reference_path.resolve().exists() or override:
                try:
                    video, fps = read_video(media_file)

                    # Process frames in batches
                    condition_frames = []
                    for i in range(0, len(video), batch_size):
                        batch = video[i : i + batch_size]
                        condition_batch = compute_reference(batch)
                        condition_frames.append(condition_batch)

                    # Concatenate all edge frames
                    all_condition = torch.cat(condition_frames, dim=0)

                    # Save the edge video
                    save_video(all_condition, reference_path.resolve(), fps=fps)
                except Exception as e:
                    console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]")
                    reference_paths.pop(rel_path)
            else:
                skipped_media.append(media_file)

            progress.advance(task)

    # Save results
    _save_dataset_json(reference_paths, output_path)

    # Print summary
    total_to_process = len(media_files) - len(skipped_media)
    console.print(
        f"[bold green]✓[/] Processed [bold]{total_to_process}/{len(media_files)}[/] media successfully.",
    )


app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Compute reference videos for IC-LoRA training.",
)


@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(  # noqa: B008
        None,
        "--output",
        "-o",
        help="Path to json output file for reference video paths. "
        "This is also the input file for the dataset, the output of caption_videos.py.",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing reference video files",
    ),
    batch_size: int = typer.Option(
        100,
        "--batch-size",
        help="Batch size for processing videos",
    ),
) -> None:
    """Compute reference videos for IC-LoRA training.

    This script generates reference videos (e.g., Canny edge maps) for given videos.
    The paths in the output file will be relative to the output file's directory.

    Examples:
        # Process all videos in a directory
        compute_reference.py videos_dir/ -o videos_dir/captions.json
    """
    if output is None:
        raise typer.BadParameter("--output is required: it is the dataset file produced by caption_videos.py.")

    # Ensure output path is absolute
    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    # Verify output path exists
    if not output.exists():
        raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.")

    # Process media files
    process_media(
        input_path=input_path,
        output_path=output,
        override=override,
        batch_size=batch_size,
    )


if __name__ == "__main__":
    app()
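A minimal sketch (not from the source, assuming compute_reference from the script above is importable) showing the tensor shapes the edge-detection helper works with:

# Illustrative sketch: shape convention of compute_reference().
import torch

frames = torch.rand(8, 3, 256, 256)  # [B, C, H, W] clip frames in [0, 1]
edges = compute_reference(frames)    # [B, 3, H, W]; Canny output, values in {0.0, 255.0}
print(edges.shape)                   # torch.Size([8, 3, 256, 256])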
|
packages/ltx-trainer/scripts/decode_latents.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3

"""
Decode precomputed video latents back into videos using the VAE.

This script loads latent files saved during preprocessing and decodes them
back into video clips using the same VAE model.

Basic usage:
    python scripts/decode_latents.py /path/to/latents/dir /path/to/output \
        --model-path /path/to/ltx2.safetensors
"""

from pathlib import Path

import torch
import torchaudio
import torchvision.utils
import typer
from einops import rearrange
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

from ltx_core.model.video_vae import SpatialTilingConfig, TemporalTilingConfig, TilingConfig
from ltx_trainer import logger
from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder
from ltx_trainer.video_utils import save_video

DEFAULT_TILE_SIZE_PIXELS = 512  # Spatial tile size in pixels (must be ≥64 and divisible by 32)
DEFAULT_TILE_OVERLAP_PIXELS = 128  # Spatial tile overlap in pixels (must be divisible by 32)
DEFAULT_TILE_SIZE_FRAMES = 128  # Temporal tile size in frames (must be ≥16 and divisible by 8)
DEFAULT_TILE_OVERLAP_FRAMES = 24  # Temporal tile overlap in frames (must be divisible by 8)

disable_progress_bar()
console = Console()
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Decode precomputed video latents back into videos using the VAE.",
)


class LatentsDecoder:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        vae_tiling: bool = False,
        with_audio: bool = False,
    ):
        """Initialize the decoder with model configuration.

        Args:
            model_path: Path to LTX-2 checkpoint (.safetensors)
            device: Device to use for computation
            vae_tiling: Whether to enable VAE tiling for larger video resolutions
            with_audio: Whether to load audio VAE for audio decoding
        """
        self.device = torch.device(device)
        self.model_path = model_path
        self.vae = None
        self.audio_vae = None
        self.vocoder = None
        self.vae_tiling = vae_tiling

        self._load_model(model_path, with_audio)

    def _load_model(self, model_path: str, with_audio: bool = False) -> None:
        """Initialize and load the VAE model(s)."""
        with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"):
            self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)

        if with_audio:
            with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"):
                self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)

            with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"):
                self.vocoder = load_vocoder(model_path, device=self.device)

    @torch.inference_mode()
    def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None:
        """Decode all latent files in the directory recursively.

        Args:
            latents_dir: Directory containing latent files (.pt)
            output_dir: Directory to save decoded videos
            seed: Optional random seed for noise generation
        """
        # Find all .pt files recursively
        latent_files = list(latents_dir.rglob("*.pt"))

        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} latent files to decode")

        # Process files with progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Decoding latents", total=len(latent_files))

            for latent_file in latent_files:
                # Calculate relative path to maintain directory structure
                rel_path = latent_file.relative_to(latents_dir)
                output_subdir = output_dir / rel_path.parent
                output_subdir.mkdir(parents=True, exist_ok=True)

                try:
                    self._process_file(latent_file, output_subdir, seed)
                except Exception as e:
                    logger.error(f"Error processing {latent_file}: {e}")

                # Advance even on errors so the bar reaches completion
                progress.advance(task)

        logger.info(f"Decoding complete! Videos saved to {output_dir}")

    @torch.inference_mode()
    def decode_audio(self, latents_dir: Path, output_dir: Path) -> None:
        """Decode all audio latent files in the directory recursively.

        Args:
            latents_dir: Directory containing audio latent files (.pt)
            output_dir: Directory to save decoded audio files
        """
        # Check if audio VAE is loaded
        if self.audio_vae is None or self.vocoder is None:
            logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.")
            return

        # Find all .pt files recursively
        latent_files = list(latents_dir.rglob("*.pt"))

        if not latent_files:
            logger.warning(f"No .pt files found in {latents_dir}")
            return

        logger.info(f"Found {len(latent_files):,} audio latent files to decode")

        # Process files with progress bar
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Decoding audio latents", total=len(latent_files))

            for latent_file in latent_files:
                # Calculate relative path to maintain directory structure
                rel_path = latent_file.relative_to(latents_dir)
                output_subdir = output_dir / rel_path.parent
                output_subdir.mkdir(parents=True, exist_ok=True)

                try:
                    self._process_audio_file(latent_file, output_subdir)
                except Exception as e:
                    logger.error(f"Error processing audio {latent_file}: {e}")

                # Advance even on errors so the bar reaches completion
                progress.advance(task)

        logger.info(f"Audio decoding complete! Audio files saved to {output_dir}")

    def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None:
        """Process a single latent file."""
        # Load the latent data
        data = torch.load(latent_file, map_location=self.device, weights_only=False)

        # Get latents - handle both old patchified [seq_len, C] and new [C, F, H, W] formats
        latents = data["latents"]
        num_frames = data["num_frames"]
        height = data["height"]
        width = data["width"]

        # Check if latents need reshaping (old patchified format)
        if latents.dim() == 2:
            # Old format: [seq_len, C] -> reshape to [C, F, H, W]
            latents = rearrange(latents, "(f h w) c -> c f h w", f=num_frames, h=height, w=width)

        # Add batch dimension: [C, F, H, W] -> [1, C, F, H, W]
        latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16)

        # Create generator only if seed is provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device)
            generator.manual_seed(seed)

        # Decode the video
        video = self._decode_video(latents, generator)

        # Determine output format and save
        is_image = video.shape[0] == 1
        if is_image:
            # Save as PNG for single frame
            output_path = output_dir / f"{latent_file.stem}.png"
            torchvision.utils.save_image(
                video[0],  # [C, H, W] in [0, 1]
                str(output_path),
            )
        else:
            # Save as MP4 for video using PyAV-based save_video
            output_path = output_dir / f"{latent_file.stem}.mp4"
            fps = data.get("fps", 24)  # Use stored FPS or default to 24
            save_video(
                video_tensor=video,  # [F, C, H, W] in [0, 1]
                output_path=output_path,
                fps=fps,
            )

    def _decode_video(self, latents: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
        """Decode latents to video frames."""
        if self.vae_tiling:
            # Use tiled decoding for reduced VRAM
            tiling_config = TilingConfig(
                spatial_config=SpatialTilingConfig(
                    tile_size_in_pixels=DEFAULT_TILE_SIZE_PIXELS,
                    tile_overlap_in_pixels=DEFAULT_TILE_OVERLAP_PIXELS,
                ),
                temporal_config=TemporalTilingConfig(
                    tile_size_in_frames=DEFAULT_TILE_SIZE_FRAMES,
                    tile_overlap_in_frames=DEFAULT_TILE_OVERLAP_FRAMES,
                ),
            )
            chunks = list(
                self.vae.tiled_decode(
                    latents,
                    tiling_config=tiling_config,
                    generator=generator,
                )
            )
            # Concatenate along temporal dimension
            video = torch.cat(chunks, dim=2)  # [B, C, F, H, W]
        else:
            # Standard full decoding
            video = self.vae(latents, generator=generator)  # [B, C, F, H, W]

        # Convert to [F, C, H, W] format and normalize to [0, 1]
        video = rearrange(video, "1 c f h w -> f c h w")
        video = (video + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
        video = video.clamp(0, 1)

        return video

    def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None:
        """Process a single audio latent file."""
        # Load the latent data
        data = torch.load(latent_file, map_location=self.device, weights_only=False)

        latents = data["latents"].to(device=self.device, dtype=torch.float32)
        num_time_steps = data["num_time_steps"]
        freq_bins = data["frequency_bins"]

        # Handle both old patchified [seq_len, C] and new [C, T, F] formats
        if latents.dim() == 2:
            # Old format: [seq_len, channels] where seq_len = time * freq
            # Reshape to [C, T, F]
            latents = rearrange(latents, "(t f) c -> c t f", t=num_time_steps, f=freq_bins)

        # Add batch dimension: [C, T, F] -> [1, C, T, F]
        latents = latents.unsqueeze(0)

        # Set correct dtype for audio VAE
        latents = latents.to(dtype=torch.bfloat16)

        # Decode audio using audio VAE decoder (produces mel spectrogram)
        mel_spectrogram = self.audio_vae(latents)

        # Convert mel spectrogram to waveform using vocoder
        waveform = self.vocoder(mel_spectrogram)

        # Save as WAV
        output_path = output_dir / f"{latent_file.stem}.wav"
        sample_rate = self.vocoder.output_sampling_rate
        torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate)


@app.command()
def main(
    latents_dir: str = typer.Argument(
        ...,
        help="Directory containing the precomputed latent files (searched recursively)",
    ),
    output_dir: str = typer.Argument(
        ...,
        help="Directory to save the decoded videos (maintains same folder hierarchy as input)",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    seed: int | None = typer.Option(
        default=None,
        help="Random seed for noise generation during decoding",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Also decode audio latents (requires audio_latents directory)",
    ),
    audio_latents_dir: str | None = typer.Option(
        default=None,
        help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)",
    ),
) -> None:
    """Decode precomputed video latents back into videos using the VAE.

    This script recursively searches for .pt latent files in the input directory
    and decodes them to videos, maintaining the same folder hierarchy in the output.

    Examples:
        # Basic usage
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors

        # With VAE tiling for large videos
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --vae-tiling

        # With audio decoding
        python scripts/decode_latents.py /path/to/latents /path/to/videos \\
            --model-path /path/to/ltx2.safetensors --with-audio
    """
    latents_path = Path(latents_dir)
    output_path = Path(output_dir)

    if not latents_path.exists() or not latents_path.is_dir():
        raise typer.BadParameter(f"Latents directory does not exist: {latents_path}")

    decoder = LatentsDecoder(
        model_path=model_path,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
    )
    decoder.decode(latents_path, output_path, seed=seed)

    # Decode audio if requested
    if with_audio:
        audio_path = Path(audio_latents_dir) if audio_latents_dir else latents_path.parent / "audio_latents"

        if audio_path.exists():
            audio_output_path = output_path.parent / "decoded_audio"
            decoder.decode_audio(audio_path, audio_output_path)
        else:
            logger.warning(f"Audio latents directory not found: {audio_path}")


if __name__ == "__main__":
    app()
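From `_process_file` above, each video latent file is a torch-saved dict with `latents`, `num_frames`, `height`, `width`, and an optional `fps`. A sketch of writing a synthetic file in that layout for smoke-testing the decoder (the latent channel count 128 is a placeholder assumption, not the real LTX-2 latent width):

# Illustrative sketch: a synthetic latent file with the keys that
# LatentsDecoder._process_file() reads.
import torch

dummy = {
    "latents": torch.randn(128, 4, 16, 16),  # [C, F, H, W]; 128 channels is a placeholder
    "num_frames": 4,
    "height": 16,
    "width": 16,
    "fps": 24,  # optional; _process_file falls back to 24
}
torch.save(dummy, "dummy_latent.pt")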
|
packages/ltx-trainer/scripts/process_captions.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python

"""
Compute text embeddings for video generation training.

This module provides functionality for processing text captions, including:
- Loading captions from various file formats (CSV, JSON, JSONL)
- Cleaning and preprocessing text (removing LLM prefixes, adding LoRA trigger tokens)
- CaptionsDataset for caption-only preprocessing workflows

Can be used as a standalone script:
    python scripts/process_captions.py dataset.json --output-dir /path/to/output \
        --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma
"""

import json
import os
from pathlib import Path
from typing import Any

import pandas as pd
import torch
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from torch.utils.data import DataLoader, Dataset
from transformers.utils.logging import disable_progress_bar

from ltx_trainer import logger
from ltx_trainer.model_loader import load_embeddings_processor, load_text_encoder

# Disable tokenizers parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

disable_progress_bar()

# Common phrases that LLMs often add to captions that we might want to remove
COMMON_BEGINNING_PHRASES: tuple[str, ...] = (
    "This video",
    "The video",
    "This clip",
    "The clip",
    "The animation",
    "This image",
    "The image",
    "This picture",
    "The picture",
)

COMMON_CONTINUATION_WORDS: tuple[str, ...] = (
    "shows",
    "depicts",
    "features",
    "captures",
    "highlights",
    "introduces",
    "presents",
)

COMMON_LLM_START_PHRASES: tuple[str, ...] = (
    "In the video,",
    "In this video,",
    "In this video clip,",
    "In the clip,",
    "Caption:",
    *(
        f"{beginning} {continuation}"
        for beginning in COMMON_BEGINNING_PHRASES
        for continuation in COMMON_CONTINUATION_WORDS
    ),
)

app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Process text captions and save embeddings for video generation training.",
)


class CaptionsDataset(Dataset):
    """
    Dataset for processing text captions only.

    This dataset is designed for caption preprocessing workflows where you only need
    to process text without loading videos. Useful for:
    - Precomputing text embeddings
    - Caption cleaning and preprocessing
    - Text-only preprocessing pipelines
    """

    def __init__(
        self,
        dataset_file: str | Path,
        caption_column: str,
        media_column: str = "media_path",
        lora_trigger: str | None = None,
        remove_llm_prefixes: bool = False,
    ) -> None:
        """
        Initialize the captions dataset.

        Args:
            dataset_file: Path to CSV/JSON/JSONL metadata file
            caption_column: Column name for captions in the metadata file
            media_column: Column name for media paths (used for output naming)
            lora_trigger: Optional trigger word to prepend to each caption
            remove_llm_prefixes: Whether to remove common LLM-generated prefixes
        """
        super().__init__()

        self.dataset_file = Path(dataset_file)
        self.caption_column = caption_column
        self.media_column = media_column
        self.lora_trigger = f"{lora_trigger.strip()} " if lora_trigger else ""

        # Load captions with their corresponding output embedding paths
        self.caption_data = self._load_caption_data()

        # Convert to lists for indexing
        self.output_paths = list(self.caption_data.keys())
        self.prompts = list(self.caption_data.values())

        # Clean LLM start phrases if requested
        if remove_llm_prefixes:
            self._clean_llm_prefixes()

    def __len__(self) -> int:
        return len(self.prompts)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Get a single caption with optional trigger word prepended and output path."""
        prompt = self.lora_trigger + self.prompts[index]
        return {
            "prompt": prompt,
            "output_path": self.output_paths[index],
            "index": index,
        }

    def _load_caption_data(self) -> dict[str, str]:
        """Load captions and compute their output embedding paths."""
        if self.dataset_file.suffix == ".csv":
            return self._load_caption_data_from_csv()
        elif self.dataset_file.suffix == ".json":
            return self._load_caption_data_from_json()
        elif self.dataset_file.suffix == ".jsonl":
            return self._load_caption_data_from_jsonl()
        else:
            raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.")

    def _load_caption_data_from_csv(self) -> dict[str, str]:
        """Load captions from a CSV file and compute output embedding paths."""
        df = pd.read_csv(self.dataset_file)

        if self.caption_column not in df.columns:
            raise ValueError(f"Column '{self.caption_column}' not found in CSV file")
        if self.media_column not in df.columns:
            raise ValueError(f"Column '{self.media_column}' not found in CSV file")

        caption_data = {}
        for _, row in df.iterrows():
            media_path = Path(row[self.media_column].strip())
            # Convert media path to embedding output path (same structure, .pt extension)
            output_path = str(media_path.with_suffix(".pt"))
            caption_data[output_path] = row[self.caption_column]

        return caption_data

    def _load_caption_data_from_json(self) -> dict[str, str]:
        """Load captions from a JSON file and compute output embedding paths."""
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            data = json.load(file)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of objects")

        caption_data = {}
        for entry in data:
            if self.caption_column not in entry:
                raise ValueError(f"Key '{self.caption_column}' not found in JSON entry: {entry}")
            if self.media_column not in entry:
                raise ValueError(f"Key '{self.media_column}' not found in JSON entry: {entry}")

            media_path = Path(entry[self.media_column].strip())
            # Convert media path to embedding output path (same structure, .pt extension)
            output_path = str(media_path.with_suffix(".pt"))
            caption_data[output_path] = entry[self.caption_column]

        return caption_data

    def _load_caption_data_from_jsonl(self) -> dict[str, str]:
        """Load captions from a JSONL file and compute output embedding paths."""
        caption_data = {}
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            for line in file:
                entry = json.loads(line)
                if self.caption_column not in entry:
                    raise ValueError(f"Key '{self.caption_column}' not found in JSONL entry: {entry}")
                if self.media_column not in entry:
                    raise ValueError(f"Key '{self.media_column}' not found in JSONL entry: {entry}")

                media_path = Path(entry[self.media_column].strip())
                # Convert media path to embedding output path (same structure, .pt extension)
                output_path = str(media_path.with_suffix(".pt"))
                caption_data[output_path] = entry[self.caption_column]

        return caption_data

    def _clean_llm_prefixes(self) -> None:
        """Remove common LLM-generated prefixes from captions."""
        for i in range(len(self.prompts)):
            self.prompts[i] = self.prompts[i].strip()
            for phrase in COMMON_LLM_START_PHRASES:
                if self.prompts[i].startswith(phrase):
                    self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()
                    break


def compute_captions_embeddings(  # noqa: PLR0913
    dataset_file: str | Path,
    output_dir: str,
    model_path: str,
    text_encoder_path: str,
    caption_column: str = "caption",
    media_column: str = "media_path",
    lora_trigger: str | None = None,
    remove_llm_prefixes: bool = False,
    batch_size: int = 8,
    device: str = "cuda",
    load_in_8bit: bool = False,
) -> None:
    """
    Process captions and save text embeddings.

    Args:
        dataset_file: Path to metadata file (CSV/JSON/JSONL) containing captions and media paths
        output_dir: Directory to save embeddings
        model_path: Path to LTX-2 checkpoint (.safetensors)
        text_encoder_path: Path to Gemma text encoder directory
        caption_column: Column name containing captions in the metadata file
        media_column: Column name containing media paths (used for output naming)
        lora_trigger: Optional trigger word to prepend to each caption
        remove_llm_prefixes: Whether to remove common LLM-generated prefixes
        batch_size: Batch size for processing
        device: Device to use for computation
        load_in_8bit: Whether to load the Gemma text encoder in 8-bit precision
    """
    console = Console()

    # Create dataset
    dataset = CaptionsDataset(
        dataset_file=dataset_file,
        caption_column=caption_column,
        media_column=media_column,
        lora_trigger=lora_trigger,
        remove_llm_prefixes=remove_llm_prefixes,
    )
    logger.info(f"Loaded {len(dataset):,} captions")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Load text encoder and embeddings processor
    with console.status("[bold]Loading Gemma text encoder...", spinner="dots"):
        text_encoder = load_text_encoder(
            text_encoder_path,
            device=device,
            dtype=torch.bfloat16,
            load_in_8bit=load_in_8bit,
        )
        embeddings_processor = load_embeddings_processor(
            model_path,
            device=device,
            dtype=torch.bfloat16,
        )

    logger.info("Text encoder and embeddings processor loaded successfully")

    # TODO(batch-tokenization): The current Gemma tokenizer doesn't support batched tokenization.
    if batch_size > 1:
        logger.warning(
            "Batch size greater than 1 is not currently supported with the Gemma tokenizer. "
            "Overriding batch_size to 1. This will be fixed in a future update."
        )
        batch_size = 1

    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Process batches
    total_batches = len(dataloader)
    logger.info(f"Processing captions in {total_batches:,} batches...")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Processing captions", total=len(dataloader))
        for batch in dataloader:
            # Encode prompts using text_encoder.encode() + feature_extractor
            # (returns video/audio features before connector).
            # The connector is applied during training via embeddings_processor
            with torch.inference_mode():
                # TODO(batch-tokenization): When tokenizer supports batching, encode all prompts at once.
                # For now, process one at a time:
                for i in range(len(batch["prompt"])):
                    hidden_states, prompt_attention_mask = text_encoder.encode(batch["prompt"][i], padding_side="left")
                    video_prompt_embeds, audio_prompt_embeds = embeddings_processor.feature_extractor(
                        hidden_states, prompt_attention_mask, "left"
                    )

                    output_rel_path = Path(batch["output_path"][i])

                    # Create output directory maintaining structure
                    output_dir_path = output_path / output_rel_path.parent
                    output_dir_path.mkdir(parents=True, exist_ok=True)

                    embedding_data = {
                        "video_prompt_embeds": video_prompt_embeds[0].cpu().contiguous(),
                        "prompt_attention_mask": prompt_attention_mask[0].cpu().contiguous(),
                    }
                    if audio_prompt_embeds is not None:
                        embedding_data["audio_prompt_embeds"] = audio_prompt_embeds[0].cpu().contiguous()

                    output_file = output_path / output_rel_path
                    torch.save(embedding_data, output_file)

            progress.advance(task)

    logger.info(f"Processed {len(dataset):,} captions. Embeddings saved to {output_path}")


@app.command()
def main(  # noqa: PLR0913
    dataset_file: str = typer.Argument(
        ...,
        help="Path to metadata file (CSV/JSON/JSONL) containing captions and media paths",
    ),
    output_dir: str = typer.Option(
        ...,
        help="Output directory to save text embeddings",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    text_encoder_path: str = typer.Option(
        ...,
        help="Path to Gemma text encoder directory",
    ),
    caption_column: str = typer.Option(
        default="caption",
        help="Column name containing captions in the dataset JSON/JSONL/CSV file",
    ),
    media_column: str = typer.Option(
        default="media_path",
        help="Column name in the dataset JSON/JSONL/CSV file containing media paths "
        "(used for output file naming and folder structure)",
    ),
    batch_size: int = typer.Option(
        default=8,
        help="Batch size for processing",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    lora_trigger: str | None = typer.Option(
        default=None,
        help="Optional trigger word to prepend to each caption (activates the LoRA during inference)",
    ),
    remove_llm_prefixes: bool = typer.Option(
        default=False,
        help="Remove common LLM-generated prefixes from captions",
    ),
    load_text_encoder_in_8bit: bool = typer.Option(
        default=False,
        help="Load the Gemma text encoder in 8-bit precision to save GPU memory (requires bitsandbytes)",
    ),
) -> None:
    """Process text captions and save embeddings for video generation training.

    This script processes captions from metadata files and saves text embeddings
    that can be used for training video generation models. The output embeddings
    will maintain the same folder structure and naming as the corresponding media files.

    Note: This script is designed for LTX-2 models which use the Gemma text encoder.

    Examples:
        # Process captions with LTX-2 model
        python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
            --model-path /path/to/ltx2_checkpoint.safetensors \\
            --text-encoder-path /path/to/gemma

        # Add a trigger word for LoRA training
        python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --lora-trigger "mytoken"

        # Remove LLM-generated prefixes from captions
        python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --remove-llm-prefixes
    """
    # Validate dataset file
    if not Path(dataset_file).is_file():
        raise typer.BadParameter(f"Dataset file not found: {dataset_file}")

    if lora_trigger:
        logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions')

    # Process embeddings
    compute_captions_embeddings(
        dataset_file=dataset_file,
        output_dir=output_dir,
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        caption_column=caption_column,
        media_column=media_column,
        lora_trigger=lora_trigger,
        remove_llm_prefixes=remove_llm_prefixes,
        batch_size=batch_size,
        device=device,
        load_in_8bit=load_text_encoder_in_8bit,
    )


if __name__ == "__main__":
    app()
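A minimal sketch (not from the source, assuming CaptionsDataset from the script above is importable; paths and captions are illustrative) exercising the dataset on a two-entry JSON metadata file:

# Illustrative sketch: CaptionsDataset prefix-cleaning and trigger prepending.
import json
from pathlib import Path

Path("tiny_dataset.json").write_text(
    json.dumps(
        [
            {"media_path": "clips/a.mp4", "caption": "The video shows a red kite."},
            {"media_path": "clips/b.mp4", "caption": "A dog runs on grass."},
        ]
    ),
    encoding="utf-8",
)

ds = CaptionsDataset(
    dataset_file="tiny_dataset.json",
    caption_column="caption",
    lora_trigger="mytoken",
    remove_llm_prefixes=True,  # strips "The video shows" from the first caption
)
print(ds[0])  # -> {'prompt': 'mytoken a red kite.', 'output_path': 'clips/a.pt', 'index': 0}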
|
packages/ltx-trainer/scripts/process_dataset.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3

"""
Preprocess a video dataset by computing video clip latents and text caption embeddings.

This script provides a command-line interface for preprocessing video datasets by computing
latent representations of video clips and text embeddings of their captions. The preprocessed
data can be used to accelerate training of video generation models and to save GPU memory.

Basic usage:
    python scripts/process_dataset.py /path/to/dataset.json --resolution-buckets 768x768x49 \
        --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma

The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.
"""

from pathlib import Path

import typer
from decode_latents import LatentsDecoder
from process_captions import compute_captions_embeddings
from process_videos import compute_latents, compute_scaled_resolution_buckets, parse_resolution_buckets
from rich.console import Console

from ltx_trainer import logger
from ltx_trainer.gpu_utils import free_gpu_memory_context

console = Console()

app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Preprocess a video dataset by computing video clip latents and text caption embeddings. "
    "The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.",
)


def preprocess_dataset(  # noqa: PLR0913
    dataset_file: str,
    caption_column: str,
    video_column: str,
    resolution_buckets: list[tuple[int, int, int]],
    batch_size: int,
    output_dir: str | None,
    lora_trigger: str | None,
    vae_tiling: bool,
    decode: bool,
    model_path: str,
    text_encoder_path: str,
    device: str,
    remove_llm_prefixes: bool = False,
    reference_column: str | None = None,
    reference_downscale_factor: int = 1,
    with_audio: bool = False,
    load_text_encoder_in_8bit: bool = False,
) -> None:
    """Run the preprocessing pipeline with the given arguments."""
    # Validate dataset file
    _validate_dataset_file(dataset_file)

    # Set up output directories
    output_base = Path(output_dir) if output_dir else Path(dataset_file).parent / ".precomputed"
    conditions_dir = output_base / "conditions"
    latents_dir = output_base / "latents"

    if lora_trigger:
        logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions')

    with free_gpu_memory_context():
        # Process captions using the dedicated function
        compute_captions_embeddings(
            dataset_file=dataset_file,
            output_dir=str(conditions_dir),
            model_path=model_path,
            text_encoder_path=text_encoder_path,
            caption_column=caption_column,
            media_column=video_column,
            lora_trigger=lora_trigger,
            remove_llm_prefixes=remove_llm_prefixes,
            batch_size=batch_size,
            device=device,
            load_in_8bit=load_text_encoder_in_8bit,
        )

    # Process videos using the dedicated function
    audio_latents_dir = None
    if with_audio:
        logger.info("Audio preprocessing enabled - will extract and encode audio from videos")
        audio_latents_dir = output_base / "audio_latents"

    with free_gpu_memory_context():
        compute_latents(
            dataset_file=dataset_file,
            video_column=video_column,
            resolution_buckets=resolution_buckets,
            output_dir=str(latents_dir),
            model_path=model_path,
            batch_size=batch_size,
            device=device,
            vae_tiling=vae_tiling,
            with_audio=with_audio,
            audio_output_dir=str(audio_latents_dir) if audio_latents_dir else None,
        )

    # Process reference videos if reference_column is provided
    if reference_column:
        # Validate: scaled references with multiple buckets can cause ambiguous bucket matching
        if reference_downscale_factor > 1 and len(resolution_buckets) > 1:
            raise ValueError(
                "When using --reference-downscale-factor > 1, only a single resolution bucket is supported. "
                "Using multiple buckets with scaled references can cause ambiguous bucket matching "
                "(e.g., a 512x256 reference could match either the scaled-down 1024x512 bucket or the 512x256 "
                "bucket). Please use a single resolution bucket or set --reference-downscale-factor to 1."
            )

        # Calculate and validate scaled resolution buckets for reference videos
        reference_buckets = compute_scaled_resolution_buckets(resolution_buckets, reference_downscale_factor)

        if reference_downscale_factor > 1:
            logger.info(
                f"Processing reference videos for IC-LoRA training at 1/{reference_downscale_factor} resolution..."
            )
            logger.info(f"Reference resolution buckets: {reference_buckets}")
        else:
            logger.info("Processing reference videos for IC-LoRA training...")

        reference_latents_dir = output_base / "reference_latents"

        compute_latents(
            dataset_file=dataset_file,
            main_media_column=video_column,
            video_column=reference_column,
            resolution_buckets=reference_buckets,
            output_dir=str(reference_latents_dir),
            model_path=model_path,
            batch_size=batch_size,
            device=device,
            vae_tiling=vae_tiling,
        )

    # Handle decoding if requested (for verification)
    if decode:
        logger.info("Decoding latents for verification...")

        decoder = LatentsDecoder(
            model_path=model_path,
            device=device,
            vae_tiling=vae_tiling,
            with_audio=with_audio,
        )
        decoder.decode(latents_dir, output_base / "decoded_videos")

        # Also decode reference videos if they exist
        if reference_column:
            reference_latents_dir = output_base / "reference_latents"
            if reference_latents_dir.exists():
                logger.info("Decoding reference videos...")
                decoder.decode(reference_latents_dir, output_base / "decoded_reference_videos")

        # Decode audio latents if they exist
        if with_audio and audio_latents_dir and audio_latents_dir.exists():
            logger.info("Decoding audio latents...")
            decoder.decode_audio(audio_latents_dir, output_base / "decoded_audio")

    # Print summary
    logger.info(f"Dataset preprocessing complete! Results saved to {output_base}")
    if reference_column:
        logger.info("Reference videos processed and saved to reference_latents/ directory for IC-LoRA training")
    if with_audio:
        logger.info("Audio latents saved to audio_latents/ directory for audio-video training")


def _validate_dataset_file(dataset_path: str) -> None:
    """Validate that the dataset file exists and has the correct format."""
    dataset_file = Path(dataset_path)

    if not dataset_file.exists():
        raise FileNotFoundError(f"Dataset file does not exist: {dataset_file}")

    if not dataset_file.is_file():
        raise ValueError(f"Dataset path must be a file, not a directory: {dataset_file}")

    if dataset_file.suffix.lower() not in [".csv", ".json", ".jsonl"]:
        raise ValueError(f"Dataset file must be CSV, JSON, or JSONL format: {dataset_file}")


@app.command()
def main(  # noqa: PLR0913
    dataset_path: str = typer.Argument(
        ...,
        help="Path to metadata file (CSV/JSON/JSONL) containing captions and video paths",
    ),
    resolution_buckets: str = typer.Option(
        ...,
        help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")',
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    text_encoder_path: str = typer.Option(
        ...,
        help="Path to Gemma text encoder directory",
    ),
    caption_column: str = typer.Option(
        default="caption",
        help="Column name containing captions in the dataset JSON/JSONL/CSV file",
    ),
    video_column: str = typer.Option(
        default="media_path",
        help="Column name containing video paths in the dataset JSON/JSONL/CSV file",
    ),
    batch_size: int = typer.Option(
        default=1,
        help="Batch size for preprocessing",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    output_dir: str | None = typer.Option(
        default=None,
        help="Output directory (defaults to .precomputed in dataset directory)",
    ),
    lora_trigger: str | None = typer.Option(
        default=None,
        help="Optional trigger word to prepend to each caption (activates the LoRA during inference)",
    ),
    decode: bool = typer.Option(
        default=False,
        help="Decode and save latents after encoding (videos and audio) for verification",
    ),
    remove_llm_prefixes: bool = typer.Option(
        default=False,
        help="Remove LLM prefixes from captions",
    ),
    reference_column: str | None = typer.Option(
        default=None,
        help="Column name containing reference video paths (for video-to-video training)",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Extract and encode audio from video files",
    ),
    load_text_encoder_in_8bit: bool = typer.Option(
        default=False,
        help="Load the Gemma text encoder in 8-bit precision to save GPU memory (requires bitsandbytes)",
    ),
    reference_downscale_factor: int = typer.Option(
        default=1,
        help="Downscale factor for reference video resolution. When > 1, reference videos are processed at "
        "1/n resolution (e.g., 2 means half resolution). Used for efficient IC-LoRA training.",
    ),
) -> None:
    """Preprocess a video dataset by computing and saving latents and text embeddings.

    The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.
    This script is designed for LTX-2 models which use the Gemma text encoder.

    Examples:
        # Process a dataset with LTX-2 model
        python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma

        # Process dataset with custom column names
        python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --caption-column "text" --video-column "video_path"

        # Process dataset with reference videos for IC-LoRA training
        python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --reference-column "reference_path"

        # Process dataset with scaled reference videos (half resolution) for efficient IC-LoRA
        python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --reference-column "reference_path" --reference-downscale-factor 2

        # Process dataset with audio for audio-video training
        python scripts/process_dataset.py dataset.json --resolution-buckets 768x512x97 \\
            --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
            --with-audio
    """
    parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets)

    if len(parsed_resolution_buckets) > 1:
        logger.warning(
            "Using multiple resolution buckets. "
            "When training with multiple resolution buckets, you must use a batch size of 1."
        )

    # Validate reference_downscale_factor
    if reference_downscale_factor < 1:
        raise typer.BadParameter("--reference-downscale-factor must be >= 1")

    if reference_downscale_factor > 1 and not reference_column:
        logger.warning("--reference-downscale-factor specified but no --reference-column provided. Ignoring.")

    preprocess_dataset(
        dataset_file=dataset_path,
        caption_column=caption_column,
        video_column=video_column,
        resolution_buckets=parsed_resolution_buckets,
        batch_size=batch_size,
        output_dir=output_dir,
        lora_trigger=lora_trigger,
        vae_tiling=vae_tiling,
        decode=decode,
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        device=device,
        remove_llm_prefixes=remove_llm_prefixes,
        reference_column=reference_column,
        reference_downscale_factor=reference_downscale_factor,
        with_audio=with_audio,
        load_text_encoder_in_8bit=load_text_encoder_in_8bit,
    )


if __name__ == "__main__":
    app()
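For readers who want to drive the pipeline from Python rather than the CLI, the sketch below calls preprocess_dataset directly with a single resolution bucket. It is illustrative only: the dataset and checkpoint paths are placeholders, and importing from process_dataset assumes the scripts/ directory is on sys.path (the script itself imports its siblings the same way).

    # Minimal programmatic sketch (paths are hypothetical placeholders).
    from process_dataset import preprocess_dataset

    preprocess_dataset(
        dataset_file="dataset.json",
        caption_column="caption",
        video_column="media_path",
        resolution_buckets=[(25, 768, 768)],  # (frames, height, width)
        batch_size=1,
        output_dir=None,  # defaults to .precomputed next to the dataset file
        lora_trigger=None,
        vae_tiling=False,
        decode=False,
        model_path="/path/to/ltx2.safetensors",
        text_encoder_path="/path/to/gemma",
        device="cuda",
    )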
packages/ltx-trainer/scripts/process_videos.py
ADDED
@@ -0,0 +1,1039 @@
#!/usr/bin/env python3

"""
Compute latent representations for video generation training.

This module provides functionality for processing video and image files, including:
- Loading videos/images from various file formats (CSV, JSON, JSONL)
- Resizing, cropping, and transforming media
- MediaDataset for video-only preprocessing workflows
- BucketSampler for grouping videos by resolution

Can be used as a standalone script:
    python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \
        --output-dir /path/to/output --model-source /path/to/ltx2.safetensors
"""

import json
import math
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
import torchaudio
import typer
from pillow_heif import register_heif_opener
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize, to_tensor
from transformers.utils.logging import disable_progress_bar

from ltx_core.model.audio_vae import AudioProcessor
from ltx_core.types import Audio
from ltx_trainer import logger
from ltx_trainer.model_loader import load_audio_vae_encoder, load_video_vae_encoder
from ltx_trainer.utils import open_image_as_srgb
from ltx_trainer.video_utils import get_video_frame_count, read_video

disable_progress_bar()

# Register HEIF/HEIC support
register_heif_opener()

# Constants for validation
VAE_SPATIAL_FACTOR = 32
VAE_TEMPORAL_FACTOR = 8

# Audio constants
AUDIO_LATENT_CHANNELS = 8
AUDIO_FREQUENCY_BINS = 16

DEFAULT_TILE_SIZE = 512  # Spatial tile size in pixels (must be ≥64 and divisible by 32)
DEFAULT_TILE_OVERLAP = 128  # Spatial tile overlap in pixels (must be divisible by 32)

app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Process videos/images and save latent representations for video generation training.",
)


class MediaDataset(Dataset):
    """
    Dataset for processing video and image files.

    This dataset is designed for media preprocessing workflows where you need to:
    - Load and preprocess videos/images
    - Apply resizing and cropping transformations
    - Handle different resolution buckets
    - Filter out invalid media files
    - Optionally extract audio from video files
    """

    def __init__(
        self,
        dataset_file: str | Path,
        main_media_column: str,
        video_column: str,
        resolution_buckets: list[tuple[int, int, int]],
        reshape_mode: str = "center",
        with_audio: bool = False,
    ) -> None:
        """
        Initialize the media dataset.

        Args:
            dataset_file: Path to CSV/JSON/JSONL metadata file
            main_media_column: Column name for the main media paths (output files are keyed by these paths)
            video_column: Column name for video paths in the metadata file
            resolution_buckets: List of (frames, height, width) tuples
            reshape_mode: How to crop videos ("center", "random")
            with_audio: Whether to extract audio from video files
        """
        super().__init__()

        self.dataset_file = Path(dataset_file)
        self.main_media_column = main_media_column
        self.resolution_buckets = resolution_buckets
        self.reshape_mode = reshape_mode
        self.with_audio = with_audio

        # First load main media paths
        self.main_media_paths = self._load_video_paths(main_media_column)

        # Then load reference video paths
        self.video_paths = self._load_video_paths(video_column)

        # Filter out videos with insufficient frames
        self._filter_valid_videos()

        self.max_target_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]

        # Set up video transforms
        self.transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: x.clamp_(0, 1)),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def __len__(self) -> int:
        return len(self.video_paths)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Get a single video/image with metadata, and optionally audio."""
        if isinstance(index, list):
            # Special case for BucketSampler - return cached data
            return index

        video_path: Path = self.video_paths[index]

        # Compute relative path of the video
        data_root = self.dataset_file.parent
        relative_path = str(video_path.relative_to(data_root))
        media_relative_path = str(self.main_media_paths[index].relative_to(data_root))

        if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
            media_tensor = self._preprocess_image(video_path)
            fps = 1.0
            audio_data = None  # Images don't have audio
        else:
            media_tensor, fps = self._preprocess_video(video_path)

            # Extract audio if enabled
            if self.with_audio:
                # Calculate target duration from the processed video frames
                # This ensures audio is trimmed to match the exact video duration
                # media_tensor is [C, F, H, W] so shape[1] is num_frames
                target_duration = media_tensor.shape[1] / fps
                audio_data = self._extract_audio(video_path, target_duration)
            else:
                audio_data = None

        # media_tensor is [C, F, H, W] format for VAE compatibility
        _, num_frames, height, width = media_tensor.shape

        result = {
            "video": media_tensor,
            "relative_path": relative_path,
            "main_media_relative_path": media_relative_path,
            "video_metadata": {
                "num_frames": num_frames,
                "height": height,
                "width": width,
                "fps": fps,
            },
        }

        # Add audio data if available
        if audio_data is not None:
            result["audio"] = audio_data

        return result

    @staticmethod
    def _extract_audio(video_path: Path, target_duration: float) -> dict[str, torch.Tensor | int] | None:
        """Extract audio track from a video file, trimmed to match video duration."""
        try:
            # torchaudio can extract audio from video files directly
            # waveform shape: [channels, samples]
            waveform, sample_rate = torchaudio.load(str(video_path))

            # Trim or pad to target duration
            target_samples = int(target_duration * sample_rate)
            current_samples = waveform.shape[-1]

            if current_samples > target_samples:
                # Trim to target duration
                waveform = waveform[..., :target_samples]
            elif current_samples < target_samples:
                # Pad with zeros to target duration
                padding = target_samples - current_samples
                waveform = torch.nn.functional.pad(waveform, (0, padding))
                logger.warning(f"Padded audio to {target_duration:.2f} seconds for {video_path}")

            return {"waveform": waveform, "sample_rate": sample_rate}

        except Exception as e:
            logger.debug(f"Could not extract audio from {video_path}: {e}")
            return None

    def _load_video_paths(self, column: str) -> list[Path]:
        """Load video paths from the specified data source."""
        if self.dataset_file.suffix == ".csv":
            return self._load_video_paths_from_csv(column)
        elif self.dataset_file.suffix == ".json":
            return self._load_video_paths_from_json(column)
        elif self.dataset_file.suffix == ".jsonl":
            return self._load_video_paths_from_jsonl(column)
        else:
            raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.")

    def _load_video_paths_from_csv(self, column: str) -> list[Path]:
        """Load video paths from a CSV file."""
        df = pd.read_csv(self.dataset_file)
        if column not in df.columns:
            raise ValueError(f"Column '{column}' not found in CSV file")

        data_root = self.dataset_file.parent
        video_paths = [data_root / Path(line.strip()) for line in df[column].tolist()]

        # Validate that all paths exist
        invalid_paths = [path for path in video_paths if not path.is_file()]
        if invalid_paths:
            raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}")

        return video_paths

    def _load_video_paths_from_json(self, column: str) -> list[Path]:
        """Load video paths from a JSON file."""
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            data = json.load(file)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of objects")

        data_root = self.dataset_file.parent
        video_paths = []
        for entry in data:
            if column not in entry:
                raise ValueError(f"Key '{column}' not found in JSON entry")
            video_paths.append(data_root / Path(entry[column].strip()))

        # Validate that all paths exist
        invalid_paths = [path for path in video_paths if not path.is_file()]
        if invalid_paths:
            raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}")

        return video_paths

    def _load_video_paths_from_jsonl(self, column: str) -> list[Path]:
        """Load video paths from a JSONL file."""
        data_root = self.dataset_file.parent
        video_paths = []
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            for line in file:
                entry = json.loads(line)
                if column not in entry:
                    raise ValueError(f"Key '{column}' not found in JSONL entry")
                video_paths.append(data_root / Path(entry[column].strip()))

        # Validate that all paths exist
        invalid_paths = [path for path in video_paths if not path.is_file()]
        if invalid_paths:
            raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}")

        return video_paths

    def _filter_valid_videos(self) -> None:
        """Filter out videos with insufficient frames."""
        original_length = len(self.video_paths)
        valid_video_paths = []
        valid_main_media_paths = []
        min_frames_required = min(self.resolution_buckets, key=lambda x: x[0])[0]

        for i, video_path in enumerate(self.video_paths):
            if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
                valid_video_paths.append(video_path)
                valid_main_media_paths.append(self.main_media_paths[i])
                continue

            try:
                frame_count = get_video_frame_count(video_path)

                if frame_count >= min_frames_required:
                    valid_video_paths.append(video_path)
                    valid_main_media_paths.append(self.main_media_paths[i])
                else:
                    logger.warning(
                        f"Skipping video at {video_path} - has {frame_count} frames, "
                        f"which is less than the minimum required frames ({min_frames_required})"
                    )
            except Exception as e:
                logger.warning(f"Failed to read video at {video_path}: {e!s}")

        # Update both path lists to maintain synchronization
        self.video_paths = valid_video_paths
        self.main_media_paths = valid_main_media_paths

        if len(self.video_paths) < original_length:
            logger.warning(
                f"Filtered out {original_length - len(self.video_paths)} videos with insufficient frames. "
                f"Proceeding with {len(self.video_paths)} valid videos."
            )

    def _preprocess_image(self, path: Path) -> torch.Tensor:
        """Preprocess a single image by resizing and applying transforms."""
        image = open_image_as_srgb(path)
        image = to_tensor(image)
        image = image.unsqueeze(0)  # Add frame dimension [1, C, H, W] for bucket selection

        # Find nearest resolution bucket and resize
        nearest_bucket = self._get_resolution_bucket_for_item(image)
        _, target_height, target_width = nearest_bucket
        image_resized = self._resize_and_crop(image, target_height, target_width)
        # _resize_and_crop returns [C, H, W] for single-frame input (squeeze removes dim 0)

        # Apply transforms
        image = self.transforms(image_resized)  # [C, H, W] -> [C, H, W]

        # Add frame dimension in VAE format: [C, H, W] -> [C, 1, H, W]
        image = image.unsqueeze(1)
        return image

    def _preprocess_video(self, path: Path) -> tuple[torch.Tensor, float]:
        """Preprocess a video by loading, resizing, and applying transforms.

        Returns:
            Tuple of (video tensor in [C, F, H, W] format, fps)
        """
        # Load video frames up to max_target_frames
        video, fps = read_video(path, max_frames=self.max_target_frames)

        nearest_bucket = self._get_resolution_bucket_for_item(video)
        target_num_frames, target_height, target_width = nearest_bucket
        frames_resized = self._resize_and_crop(video, target_height, target_width)

        # Trim video to target number of frames
        frames_resized = frames_resized[:target_num_frames]

        # Apply transforms to each frame and stack
        video = torch.stack([self.transforms(frame) for frame in frames_resized], dim=0)

        # Permute [F, C, H, W] -> [C, F, H, W] for VAE compatibility
        # After DataLoader batching, this becomes [B, C, F, H, W] which the VAE expects
        video = video.permute(1, 0, 2, 3).contiguous()

        return video, fps

    def _get_resolution_bucket_for_item(self, media_tensor: torch.Tensor) -> tuple[int, int, int]:
        """Get the nearest resolution bucket for the given media tensor."""
        num_frames, _, height, width = media_tensor.shape

        def distance(bucket: tuple[int, int, int]) -> tuple:
            bucket_num_frames, bucket_height, bucket_width = bucket
            # Lexicographic key:
            # 1) minimize aspect-ratio diff (in log-scale, for invariance to shorter/longer ARs)
            # 2) prefer buckets with more frames (by using negative)
            # 3) prefer buckets with larger spatial area (by using negative)
            return (
                abs(math.log(width / height) - math.log(bucket_width / bucket_height)),
                -bucket_num_frames,
                -(bucket_height * bucket_width),
            )

        # Keep only buckets with <= available frames
        relevant_buckets = [b for b in self.resolution_buckets if b[0] <= num_frames]
        if not relevant_buckets:
            raise ValueError(f"No resolution buckets have <= {num_frames} frames. Available: {self.resolution_buckets}")

        # Find the bucket with the minimal distance (according to the function above) to the media item's shape.
        nearest_bucket = min(relevant_buckets, key=distance)

        return nearest_bucket

    def _resize_and_crop(self, media_tensor: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor:
        """Resize and crop tensor to target size."""
        # Get current dimensions
        current_height, current_width = media_tensor.shape[2], media_tensor.shape[3]

        # Calculate aspect ratios to determine which dimension to resize first
        current_aspect = current_width / current_height
        target_aspect = target_width / target_height

        # Resize while maintaining aspect ratio - scale so the target region is fully covered
        if current_aspect > target_aspect:
            # Current is wider than target, so scale by height
            new_width = int(current_width * target_height / current_height)
            media_tensor = resize(
                media_tensor,
                size=[target_height, new_width],  # type: ignore
                interpolation=InterpolationMode.BICUBIC,
            )
        else:
            # Current is taller than target, so scale by width
            new_height = int(current_height * target_width / current_width)
            media_tensor = resize(
                media_tensor,
                size=[new_height, target_width],
                interpolation=InterpolationMode.BICUBIC,
            )

        # Update dimensions after resize
        current_height, current_width = media_tensor.shape[2], media_tensor.shape[3]
        media_tensor = media_tensor.squeeze(0)

        # Calculate how much we need to crop from each dimension
        delta_h = current_height - target_height
        delta_w = current_width - target_width

        # Determine crop position based on reshape mode
        if self.reshape_mode == "random":
            # Random crop position
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif self.reshape_mode == "center":
            # Center crop
            top, left = delta_h // 2, delta_w // 2
        else:
            raise ValueError(f"Unsupported reshape mode: {self.reshape_mode}")

        # Perform the final crop to exact target dimensions
        media_tensor = crop(media_tensor, top=top, left=left, height=target_height, width=target_width)
        return media_tensor
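# --- Illustrative sketch (not part of the original file): iterating MediaDataset.
# The dataset path and column names below are hypothetical. Each item is a dict
# with a [C, F, H, W] "video" tensor plus path and metadata fields, so a plain
# DataLoader collates batches into the [B, C, F, H, W] layout that
# compute_latents() below feeds to the VAE (3 channels assumed for RGB clips).
#
#     dataset = MediaDataset(
#         dataset_file="dataset.csv",
#         main_media_column="media_path",
#         video_column="media_path",
#         resolution_buckets=[(25, 768, 768)],  # (frames, height, width)
#     )
#     loader = DataLoader(dataset, batch_size=1, shuffle=False)
#     batch = next(iter(loader))
#     print(batch["video"].shape)  # e.g. torch.Size([1, 3, 25, 768, 768])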
def compute_latents(  # noqa: PLR0913, PLR0915
    dataset_file: str | Path,
    video_column: str,
    resolution_buckets: list[tuple[int, int, int]],
    output_dir: str,
    model_path: str,
    main_media_column: str | None = None,
    reshape_mode: str = "center",
    batch_size: int = 1,
    device: str = "cuda",
    vae_tiling: bool = False,
    with_audio: bool = False,
    audio_output_dir: str | None = None,
) -> None:
    """
    Process videos and save latent representations.

    Args:
        dataset_file: Path to metadata file (CSV/JSON/JSONL) containing video paths
        video_column: Column name for video paths in the metadata file
        resolution_buckets: List of (frames, height, width) tuples
        output_dir: Directory to save video latents
        model_path: Path to LTX-2 checkpoint (.safetensors)
        main_media_column: Column name for main media paths (if different from video_column)
        reshape_mode: How to crop videos ("center", "random")
        batch_size: Batch size for processing
        device: Device to use for computation
        vae_tiling: Whether to enable VAE tiling
        with_audio: Whether to extract and encode audio from videos
        audio_output_dir: Directory to save audio latents (required if with_audio=True)
    """
    # Validate audio parameters
    if with_audio and audio_output_dir is None:
        raise ValueError("audio_output_dir must be provided when with_audio=True")

    console = Console()
    torch_device = torch.device(device)

    # Create dataset
    dataset = MediaDataset(
        dataset_file=dataset_file,
        main_media_column=main_media_column or video_column,
        video_column=video_column,
        resolution_buckets=resolution_buckets,
        reshape_mode=reshape_mode,
        with_audio=with_audio,
    )
    logger.info(f"Loaded {len(dataset)} valid media files")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Set up audio output directory if needed
    audio_output_path = None
    if with_audio:
        audio_output_path = Path(audio_output_dir)
        audio_output_path.mkdir(parents=True, exist_ok=True)

    # Load video VAE encoder
    with console.status(f"[bold]Loading video VAE encoder from [cyan]{model_path}[/]...", spinner="dots"):
        vae = load_video_vae_encoder(model_path, device=torch_device, dtype=torch.bfloat16)

    # Load audio VAE encoder and audio processor if needed
    audio_vae_encoder = None
    audio_processor = None
    if with_audio:
        with console.status(f"[bold]Loading audio VAE encoder from [cyan]{model_path}[/]...", spinner="dots"):
            audio_vae_encoder = load_audio_vae_encoder(
                checkpoint_path=model_path,
                device=torch_device,
                dtype=torch.float32,  # Audio VAE needs float32 for quality. TODO: re-test with bfloat16.
            )
        # Create audio processor for waveform-to-spectrogram conversion
        audio_processor = AudioProcessor(
            target_sample_rate=audio_vae_encoder.sample_rate,
            mel_bins=audio_vae_encoder.mel_bins,
            mel_hop_length=audio_vae_encoder.mel_hop_length,
            n_fft=audio_vae_encoder.n_fft,
        ).to(torch_device)

    # Create dataloader
    # Note: batch_size=1 required when with_audio because audio extraction can fail for some videos,
    # and the default collate function can't handle mixed None/dict values across a batch.
    if with_audio and batch_size > 1:
        logger.warning("Audio processing requires batch_size=1. Overriding batch_size to 1.")
        batch_size = 1
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Track audio statistics
    audio_success_count = 0
    audio_skip_count = 0

    # Process batches
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Processing videos", total=len(dataloader))

        for batch in dataloader:
            # Get video tensor - shape is [B, C, F, H, W] after DataLoader collation
            video = batch["video"]

            # Encode video
            with torch.inference_mode():
                video_latent_data = encode_video(vae=vae, video=video, use_tiling=vae_tiling)

            # Save latents for each item in batch
            for i in range(len(batch["relative_path"])):
                output_rel_path = Path(batch["main_media_relative_path"][i]).with_suffix(".pt")
                output_file = output_path / output_rel_path

                # Create output directory maintaining structure
                output_file.parent.mkdir(parents=True, exist_ok=True)

                # Index into batch to get this item's latents
                latent_data = {
                    "latents": video_latent_data["latents"][i].cpu().contiguous(),  # [C, F', H', W']
                    "num_frames": video_latent_data["num_frames"],
                    "height": video_latent_data["height"],
                    "width": video_latent_data["width"],
                    "fps": batch["video_metadata"]["fps"][i].item(),
                }

                torch.save(latent_data, output_file)

                # Process audio if enabled (audio is already extracted by the dataset)
                if with_audio:
                    audio_batch = batch.get("audio")
                    if audio_batch is not None:
                        # Extract the i-th item from batched audio data
                        # DataLoader collates [channels, samples] -> [batch, channels, samples]
                        audio_data = Audio(
                            waveform=audio_batch["waveform"][i],
                            sampling_rate=audio_batch["sample_rate"][i].item(),
                        )

                        # Encode audio
                        with torch.inference_mode():
                            audio_latents = encode_audio(audio_vae_encoder, audio_processor, audio_data)

                        # Save audio latents
                        audio_output_file = audio_output_path / output_rel_path
                        audio_output_file.parent.mkdir(parents=True, exist_ok=True)

                        audio_save_data = {
                            "latents": audio_latents["latents"].cpu().contiguous(),
                            "num_time_steps": audio_latents["num_time_steps"],
                            "frequency_bins": audio_latents["frequency_bins"],
                            "duration": audio_latents["duration"],
                        }

                        torch.save(audio_save_data, audio_output_file)
                        audio_success_count += 1
                    else:
                        # Video has no audio track
                        audio_skip_count += 1

            progress.advance(task)

    # Log summary
    logger.info(f"Processed {len(dataset)} videos. Latents saved to {output_path}")
    if with_audio:
        logger.info(
            f"Audio processing: {audio_success_count} videos with audio, "
            f"{audio_skip_count} videos without audio (skipped)"
        )
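# --- Illustrative sketch (not part of the original file): inspecting saved latents.
# compute_latents() writes one .pt file per input clip, mirroring the dataset's
# directory layout; the path below is a hypothetical example of loading one back.
#
#     data = torch.load("out/latents/clip_0001.pt")
#     print(sorted(data))           # ['fps', 'height', 'latents', 'num_frames', 'width']
#     print(data["latents"].shape)  # [C, F', H', W'] non-patchified latents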
def encode_video(
    vae: torch.nn.Module,
    video: torch.Tensor,
    dtype: torch.dtype | None = None,
    use_tiling: bool = False,
    tile_size: int = DEFAULT_TILE_SIZE,
    tile_overlap: int = DEFAULT_TILE_OVERLAP,
) -> dict[str, torch.Tensor | int]:
    """Encode video into non-patchified latent representation.

    Args:
        vae: Video VAE encoder model
        video: Input tensor of shape [B, C, F, H, W] (batch, channels, frames, height, width).
            This is the format expected by the VAE encoder.
        dtype: Target dtype for output latents
        use_tiling: Whether to use spatial tiling for memory efficiency
        tile_size: Tile size in pixels (must be divisible by 32)
        tile_overlap: Overlap between tiles in pixels (must be divisible by 32)

    Returns:
        Dict containing non-patchified latents and shape information:
        {
            "latents": Tensor[B, C, F', H', W'],  # Non-patchified format with batch dim
            "num_frames": int,  # Latent frame count
            "height": int,  # Latent height
            "width": int,  # Latent width
        }
    """
    device = next(vae.parameters()).device
    vae_dtype = next(vae.parameters()).dtype

    # Add batch dimension if needed
    if video.ndim == 4:
        video = video.unsqueeze(0)  # [C, F, H, W] -> [B, C, F, H, W]

    video = video.to(device=device, dtype=vae_dtype)

    # Choose encoding method based on tiling flag
    if use_tiling:
        latents = tiled_encode_video(
            vae=vae,
            video=video,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
        )
    else:
        # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W']
        latents = vae(video)

    if dtype is not None:
        latents = latents.to(dtype=dtype)

    _, _, num_frames, height, width = latents.shape

    return {
        "latents": latents,  # [B, C, F', H', W']
        "num_frames": num_frames,
        "height": height,
        "width": width,
    }
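# --- Illustrative sketch (not part of the original file): latent shape arithmetic.
# With VAE_SPATIAL_FACTOR = 32 and VAE_TEMPORAL_FACTOR = 8, a 25-frame 768x768 clip
# encodes to 1 + (25 - 1) // 8 = 4 latent frames at 768 // 32 = 24x24 (per the
# compression factors above); `vae` is assumed to be loaded via load_video_vae_encoder.
#
#     video = torch.randn(1, 3, 25, 768, 768)
#     out = encode_video(vae=vae, video=video)
#     assert out["latents"].shape[2:] == (4, 24, 24)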
def tiled_encode_video(  # noqa: PLR0912, PLR0915
    vae: torch.nn.Module,
    video: torch.Tensor,
    tile_size: int = DEFAULT_TILE_SIZE,
    tile_overlap: int = DEFAULT_TILE_OVERLAP,
) -> torch.Tensor:
    """Encode video using spatial tiling for memory efficiency.

    Splits the video into overlapping spatial tiles, encodes each tile separately,
    and blends the results using linear feathering in the overlap regions.

    Args:
        vae: Video VAE encoder model
        video: Input tensor of shape [B, C, F, H, W]
        tile_size: Tile size in pixels (must be divisible by 32)
        tile_overlap: Overlap between tiles in pixels (must be divisible by 32)

    Returns:
        Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]
    """
    batch, _channels, frames, height, width = video.shape
    device = video.device
    dtype = video.dtype

    # Validate tile parameters
    if tile_size % VAE_SPATIAL_FACTOR != 0:
        raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
    if tile_overlap % VAE_SPATIAL_FACTOR != 0:
        raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
    if tile_overlap >= tile_size:
        raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")

    # If video fits in a single tile, use regular encoding
    if height <= tile_size and width <= tile_size:
        return vae(video)

    # Calculate output dimensions
    # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8
    output_height = height // VAE_SPATIAL_FACTOR
    output_width = width // VAE_SPATIAL_FACTOR
    output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR

    # Latent channels (128 for LTX-2)
    # Get from a small test encode or assume 128
    latent_channels = 128

    # Initialize output and weight tensors
    output = torch.zeros(
        (batch, latent_channels, output_frames, output_height, output_width),
        device=device,
        dtype=dtype,
    )
    weights = torch.zeros(
        (batch, 1, output_frames, output_height, output_width),
        device=device,
        dtype=dtype,
    )

    # Calculate tile positions with overlap
    # Step size is tile_size - tile_overlap
    step_h = tile_size - tile_overlap
    step_w = tile_size - tile_overlap

    h_positions = list(range(0, max(1, height - tile_overlap), step_h))
    w_positions = list(range(0, max(1, width - tile_overlap), step_w))

    # Ensure last tile covers the edge
    if h_positions[-1] + tile_size < height:
        h_positions.append(height - tile_size)
    if w_positions[-1] + tile_size < width:
        w_positions.append(width - tile_size)

    # Remove duplicates and sort
    h_positions = sorted(set(h_positions))
    w_positions = sorted(set(w_positions))

    # Overlap in latent space
    overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
    overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR

    # Process each tile
    for h_pos in h_positions:
        for w_pos in w_positions:
            # Calculate tile boundaries in input space
            h_start = max(0, h_pos)
            w_start = max(0, w_pos)
            h_end = min(h_start + tile_size, height)
            w_end = min(w_start + tile_size, width)

            # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR
            tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
            tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR

            if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
                continue

            # Adjust end positions
            h_end = h_start + tile_h
            w_end = w_start + tile_w

            # Extract tile
            tile = video[:, :, :, h_start:h_end, w_start:w_end]

            # Encode tile
            encoded_tile = vae(tile)

            # Get actual encoded dimensions
            _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape

            # Calculate output positions
            out_h_start = h_start // VAE_SPATIAL_FACTOR
            out_w_start = w_start // VAE_SPATIAL_FACTOR
            out_h_end = min(out_h_start + tile_out_height, output_height)
            out_w_end = min(out_w_start + tile_out_width, output_width)

            # Trim encoded tile if necessary
            actual_tile_h = out_h_end - out_h_start
            actual_tile_w = out_w_end - out_w_start
            encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]

            # Create blending mask with linear feathering at edges
            mask = torch.ones(
                (1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
                device=device,
                dtype=dtype,
            )

            # Apply feathering at edges (linear blend in overlap regions)
            # Top edge (height dimension)
            if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)

            # Bottom edge (height dimension)
            if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
                fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)

            # Left edge (width dimension)
            if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)

            # Right edge (width dimension)
            if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
                fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
                mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)

            # Accumulate weighted results
            output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
            weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask

    # Normalize by weights (avoid division by zero)
    output = output / (weights + 1e-8)

    return output
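# --- Illustrative sketch (not part of the original file): tile placement math.
# For a 1024x1024 frame with the default tile_size=512 and tile_overlap=128, the
# step is 512 - 128 = 384, so range(0, 1024 - 128, 384) yields tiles starting at
# 0, 384, and 768 along each axis; the tile at 768 ends exactly at 1024, so no
# extra edge tile is appended. Each interior boundary then gets a 128-pixel
# (128 // 32 = 4 latent cells) linear cross-fade from the feathering masks above.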
def encode_audio(
    audio_vae_encoder: torch.nn.Module,
    audio_processor: torch.nn.Module,
    audio: Audio,
) -> dict[str, torch.Tensor | int | float]:
    """Encode audio waveform into latent representation.

    Args:
        audio_vae_encoder: Audio VAE encoder model from ltx-core
        audio_processor: AudioProcessor for waveform-to-spectrogram conversion
        audio: Audio container with waveform tensor and sampling rate.

    Returns:
        Dict containing audio latents and shape information:
        {
            "latents": Tensor[C, T, F],  # Non-patchified format
            "num_time_steps": int,
            "frequency_bins": int,
            "duration": float,
        }
    """
    device = next(audio_vae_encoder.parameters()).device
    dtype = next(audio_vae_encoder.parameters()).dtype

    waveform = audio.waveform.to(device=device, dtype=dtype)

    # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
    if waveform.dim() == 2:
        waveform = waveform.unsqueeze(0)

    # Calculate duration
    duration = waveform.shape[-1] / audio.sampling_rate

    # Convert waveform to mel spectrogram using AudioProcessor
    mel_spectrogram = audio_processor.waveform_to_mel(Audio(waveform=waveform, sampling_rate=audio.sampling_rate))
    mel_spectrogram = mel_spectrogram.to(dtype=dtype)

    # Encode mel spectrogram to latents
    latents = audio_vae_encoder(mel_spectrogram)

    # latents shape: [batch, channels, time, freq] = [1, 8, T, 16]
    _, _channels, time_steps, freq_bins = latents.shape

    return {
        "latents": latents.squeeze(0),  # [C, T, F] - remove batch dim
        "num_time_steps": time_steps,
        "frequency_bins": freq_bins,
        "duration": duration,
    }
| 872 |
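
def _example_audio_latent_shapes() -> None:
    """Shape bookkeeping sketch for encode_audio (hypothetical numbers, no model involved).

    waveform [C, S] -> batched [1, C, S] -> mel -> latents [1, 8, T, 16] -> squeezed [8, T, 16].
    """
    sampling_rate = 48_000
    waveform = torch.randn(2, 2 * sampling_rate)  # 2-second stereo clip, [C, S]
    duration = waveform.shape[-1] / sampling_rate
    assert duration == 2.0

    latents = torch.randn(1, 8, 50, 16)  # stand-in for the encoder output [B, C, T, F]
    _, _, time_steps, freq_bins = latents.shape
    assert latents.squeeze(0).shape == (8, 50, 16)
    assert (time_steps, freq_bins) == (50, 16)
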
def parse_resolution_buckets(resolution_buckets_str: str) -> list[tuple[int, int, int]]:
    """Parse resolution buckets from string format into a list of (frames, height, width) tuples."""
    resolution_buckets = []
    for bucket_str in resolution_buckets_str.split(";"):
        w, h, f = map(int, bucket_str.split("x"))

        if w % VAE_SPATIAL_FACTOR != 0 or h % VAE_SPATIAL_FACTOR != 0:
            raise typer.BadParameter(
                f"Width and height must be multiples of {VAE_SPATIAL_FACTOR}, got {w}x{h}",
                param_hint="resolution-buckets",
            )

        if f % VAE_TEMPORAL_FACTOR != 1:
            raise typer.BadParameter(
                f"Number of frames must be a multiple of {VAE_TEMPORAL_FACTOR} plus 1, got {f}",
                param_hint="resolution-buckets",
            )

        resolution_buckets.append((f, h, w))
    return resolution_buckets

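
def _example_parse_buckets() -> None:
    """Illustrative only; assumes VAE_SPATIAL_FACTOR == 32 and VAE_TEMPORAL_FACTOR == 8,
    the values the constants above are expected to hold for the LTX-2 VAE."""
    buckets = parse_resolution_buckets("768x768x25;512x512x49")
    assert buckets == [(25, 768, 768), (49, 512, 512)]  # WxHxF in, (F, H, W) out
    # "768x768x24" would raise typer.BadParameter, since 24 % 8 != 1.
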
def compute_scaled_resolution_buckets(
    resolution_buckets: list[tuple[int, int, int]],
    scale_factor: int,
) -> list[tuple[int, int, int]]:
    """Compute scaled resolution buckets and validate the results."""
    if scale_factor == 1:
        return resolution_buckets

    scaled_buckets = []
    for frames, height, width in resolution_buckets:
        # Validate that the scale factor evenly divides the dimensions
        if height % scale_factor != 0:
            raise ValueError(
                f"Height {height} is not evenly divisible by scale factor {scale_factor}. "
                f"Choose a scale factor that divides {height} evenly."
            )
        if width % scale_factor != 0:
            raise ValueError(
                f"Width {width} is not evenly divisible by scale factor {scale_factor}. "
                f"Choose a scale factor that divides {width} evenly."
            )

        scaled_height = height // scale_factor
        scaled_width = width // scale_factor

        # Validate scaled dimensions are divisible by the VAE spatial factor
        if scaled_height % VAE_SPATIAL_FACTOR != 0:
            raise ValueError(
                f"Scaled height {scaled_height} (from {height} / {scale_factor}) "
                f"is not divisible by {VAE_SPATIAL_FACTOR}. "
                f"Choose a different scale factor or adjust your resolution buckets."
            )
        if scaled_width % VAE_SPATIAL_FACTOR != 0:
            raise ValueError(
                f"Scaled width {scaled_width} (from {width} / {scale_factor}) "
                f"is not divisible by {VAE_SPATIAL_FACTOR}. "
                f"Choose a different scale factor or adjust your resolution buckets."
            )

        scaled_buckets.append((frames, scaled_height, scaled_width))

    return scaled_buckets

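
def _example_scaled_buckets() -> None:
    """Illustrative only; assumes VAE_SPATIAL_FACTOR == 32 as above."""
    scaled = compute_scaled_resolution_buckets([(25, 768, 768)], scale_factor=2)
    assert scaled == [(25, 384, 384)]  # frames untouched, spatial dims halved
    # scale_factor=5 would raise ValueError (768 is not divisible by 5);
    # scale_factor=48 divides evenly (768 / 48 = 16) but fails the 16 % 32 check.
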
@app.command()
def main(  # noqa: PLR0913
    dataset_file: str = typer.Argument(
        ...,
        help="Path to metadata file (CSV/JSON/JSONL) containing video paths",
    ),
    resolution_buckets: str = typer.Option(
        ...,
        help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")',
    ),
    output_dir: str = typer.Option(
        ...,
        help="Output directory to save video latents",
    ),
    model_path: str = typer.Option(
        ...,
        help="Path to LTX-2 checkpoint (.safetensors file)",
    ),
    video_column: str = typer.Option(
        default="media_path",
        help="Column name in the dataset JSON/JSONL/CSV file containing video paths",
    ),
    batch_size: int = typer.Option(
        default=1,
        help="Batch size for processing",
    ),
    device: str = typer.Option(
        default="cuda",
        help="Device to use for computation",
    ),
    vae_tiling: bool = typer.Option(
        default=False,
        help="Enable VAE tiling for larger video resolutions",
    ),
    reshape_mode: str = typer.Option(
        default="center",
        help="How to crop videos: 'center' or 'random'",
    ),
    with_audio: bool = typer.Option(
        default=False,
        help="Extract and encode audio from video files",
    ),
    audio_output_dir: str | None = typer.Option(
        default=None,
        help="Output directory for audio latents (required if --with-audio is set)",
    ),
) -> None:
    """Process videos/images and save latent representations for video generation training.

    This script processes videos and images from metadata files and saves latent representations
    that can be used for training video generation models. The output latents will maintain
    the same folder structure and naming as the corresponding media files.

    Examples:
        # Process videos from a CSV file
        python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors

        # Process videos from a JSON file with custom video column
        python scripts/process_videos.py dataset.json --resolution-buckets 768x768x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors --video-column "video_path"

        # Enable VAE tiling to save GPU VRAM
        python scripts/process_videos.py dataset.csv --resolution-buckets 1024x1024x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors --vae-tiling

        # Process videos with audio
        python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\
            --output-dir ./latents --model-path /path/to/ltx2.safetensors \\
            --with-audio --audio-output-dir ./audio_latents
    """
    # Validate dataset file exists
    if not Path(dataset_file).is_file():
        raise typer.BadParameter(f"Dataset file not found: {dataset_file}")

    # Validate audio parameters
    if with_audio and audio_output_dir is None:
        raise typer.BadParameter("--audio-output-dir is required when --with-audio is set")

    # Parse resolution buckets
    parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets)

    if len(parsed_resolution_buckets) > 1:
        logger.warning(
            "Using multiple resolution buckets. "
            "When training with multiple resolution buckets, you must use a batch size of 1."
        )

    # Process latents
    compute_latents(
        dataset_file=dataset_file,
        video_column=video_column,
        resolution_buckets=parsed_resolution_buckets,
        output_dir=output_dir,
        model_path=model_path,
        reshape_mode=reshape_mode,
        batch_size=batch_size,
        device=device,
        vae_tiling=vae_tiling,
        with_audio=with_audio,
        audio_output_dir=audio_output_dir,
    )


if __name__ == "__main__":
    app()
packages/ltx-trainer/scripts/split_scenes.py
ADDED
@@ -0,0 +1,417 @@
#!/usr/bin/env python3

"""
Split video into scenes using PySceneDetect.

This script provides a command-line interface for splitting videos into scenes using various detection algorithms.
It supports multiple detection methods, preview image generation, and customizable parameters for fine-tuning
the scene detection process.

Basic usage:
    # Split video using default content-based detection
    split_scenes.py input.mp4 output_dir/

    # Save 3 preview images per scene
    split_scenes.py input.mp4 output_dir/ --save-images 3

    # Process specific duration and filter short scenes
    split_scenes.py input.mp4 output_dir/ --duration 60s --filter-shorter-than 2s

Advanced usage:
    # Content detection with minimum scene length and frame skip
    split_scenes.py input.mp4 output_dir/ --detector content --min-scene-length 30 --frame-skip 2

    # Use adaptive detection with custom detector parameters
    split_scenes.py input.mp4 output_dir/ --detector adaptive --threshold 3.0 --adaptive-window 10
"""

from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Tuple

import typer
from scenedetect import (
    AdaptiveDetector,
    ContentDetector,
    HistogramDetector,
    SceneManager,
    ThresholdDetector,
    open_video,
)
from scenedetect.frame_timecode import FrameTimecode
from scenedetect.scene_manager import SceneDetector, write_scene_list_html
from scenedetect.scene_manager import save_images as save_scene_images
from scenedetect.stats_manager import StatsManager
from scenedetect.video_splitter import split_video_ffmpeg

app = typer.Typer(no_args_is_help=True, help="Split video into scenes using PySceneDetect.")


class DetectorType(str, Enum):
    """Available scene detection algorithms."""

    CONTENT = "content"  # Detects fast cuts using HSV color space
    ADAPTIVE = "adaptive"  # Two-pass variant of content detection, more robust to camera movement
    THRESHOLD = "threshold"  # Detects slow fades in/out based on average frame intensity vs. a threshold level
    HISTOGRAM = "histogram"  # Detects based on YUV histogram differences in adjacent frames


def create_detector(
    detector_type: DetectorType,
    threshold: Optional[float] = None,
    min_scene_len: Optional[int] = None,
    luma_only: Optional[bool] = None,
    adaptive_window: Optional[int] = None,
    fade_bias: Optional[float] = None,
) -> SceneDetector:
    """Create a scene detector based on the specified type and parameters.

    Args:
        detector_type: Type of detector to create
        threshold: Detection threshold (meaning varies by detector)
        min_scene_len: Minimum scene length in frames
        luma_only: If True, only use brightness for content detection
        adaptive_window: Window size for adaptive detection
        fade_bias: Bias for fade in/out detection (-1.0 to 1.0)

    Note: Parameters set to None will use the detector's built-in default values.

    Returns:
        Configured scene detector instance
    """
    # Set common arguments
    kwargs = {}
    if threshold is not None:
        kwargs["threshold"] = threshold

    if min_scene_len is not None:
        kwargs["min_scene_len"] = min_scene_len

    match detector_type:
        case DetectorType.CONTENT:
            if luma_only is not None:
                kwargs["luma_only"] = luma_only
            return ContentDetector(**kwargs)
        case DetectorType.ADAPTIVE:
            if adaptive_window is not None:
                kwargs["window_width"] = adaptive_window
            if luma_only is not None:
                kwargs["luma_only"] = luma_only
            if "threshold" in kwargs:
                # Special case: AdaptiveDetector uses a different parameter name
                kwargs["adaptive_threshold"] = kwargs.pop("threshold")
            return AdaptiveDetector(**kwargs)
        case DetectorType.THRESHOLD:
            if fade_bias is not None:
                kwargs["fade_bias"] = fade_bias
            return ThresholdDetector(**kwargs)
        case DetectorType.HISTOGRAM:
            return HistogramDetector(**kwargs)
        case _:
            raise ValueError(f"Unknown detector type: {detector_type}")

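
def _example_detectors() -> None:
    """Illustrative factory calls; the thresholds shown are plausible starting points, not tuned values."""
    fast_cuts = create_detector(DetectorType.CONTENT, threshold=27.0, min_scene_len=15)
    adaptive = create_detector(DetectorType.ADAPTIVE, threshold=3.0, adaptive_window=10)
    # For ADAPTIVE, `threshold` is forwarded as `adaptive_threshold`, matching
    # AdaptiveDetector's parameter name in PySceneDetect.
    assert isinstance(fast_cuts, ContentDetector)
    assert isinstance(adaptive, AdaptiveDetector)
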
def validate_output_dir(output_dir: str) -> Path:
    """Validate the output directory path, ensuring it is not an existing non-directory file.

    Args:
        output_dir: Path to the output directory

    Returns:
        Path object of the validated output directory
    """
    path = Path(output_dir)

    if path.exists() and not path.is_dir():
        raise typer.BadParameter(f"{output_dir} exists but is not a directory")

    return path


def parse_timecode(video: Any, time_str: Optional[str]) -> Optional[FrameTimecode]:
    """Parse a timecode string into a FrameTimecode object.

    Supports formats:
    - Frames: '123'
    - Seconds: '123s' or '123.45s'
    - Timecode: '00:02:03' or '00:02:03.456'

    Args:
        video: Video object to get the framerate from
        time_str: String to parse, or None

    Returns:
        FrameTimecode object, or None if the input is None
    """
    if time_str is None:
        return None

    try:
        if time_str.endswith("s"):
            # Seconds format
            seconds = float(time_str[:-1])
            return FrameTimecode(timecode=seconds, fps=video.frame_rate)
        elif ":" in time_str:
            # Timecode format
            return FrameTimecode(timecode=time_str, fps=video.frame_rate)
        else:
            # Frame number format
            return FrameTimecode(timecode=int(time_str), fps=video.frame_rate)
    except ValueError as e:
        raise typer.BadParameter(
            f"Invalid timecode format: {time_str}. Use frames (123), "
            f"seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn])",
        ) from e

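
def _example_timecodes() -> None:
    """Illustrative only; assumes a hypothetical 25 fps clip, so all three forms land on frame 50."""
    video = open_video("input.mp4", backend="opencv")  # hypothetical 25 fps source
    assert parse_timecode(video, "50").get_frames() == 50  # frame number
    assert parse_timecode(video, "2s").get_frames() == 50  # seconds: 2.0 * 25
    assert parse_timecode(video, "00:00:02").get_frames() == 50  # timecode
    assert parse_timecode(video, None) is None
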
def detect_and_split_scenes(  # noqa: PLR0913
    video_path: str,
    output_dir: Path,
    detector_type: DetectorType,
    threshold: Optional[float] = None,
    min_scene_len: Optional[int] = None,
    max_scenes: Optional[int] = None,
    filter_shorter_than: Optional[str] = None,
    skip_start: Optional[int] = None,  # noqa: ARG001
    skip_end: Optional[int] = None,  # noqa: ARG001
    save_images_per_scene: int = 0,
    stats_file: Optional[str] = None,
    luma_only: bool = False,
    adaptive_window: Optional[int] = None,
    fade_bias: Optional[float] = None,
    downscale_factor: Optional[int] = None,
    frame_skip: int = 0,
    duration: Optional[str] = None,
) -> List[Tuple[FrameTimecode, FrameTimecode]]:
    """Detect and split scenes in a video using the specified parameters.

    Args:
        video_path: Path to input video.
        output_dir: Directory to save output split scenes.
        detector_type: Type of scene detector to use.
        threshold: Detection threshold.
        min_scene_len: Minimum scene length in frames.
        max_scenes: Maximum number of scenes to detect.
        filter_shorter_than: Filter out scenes shorter than this duration (frames/seconds/timecode).
        skip_start: Number of frames to skip at start.
        skip_end: Number of frames to skip at end.
        save_images_per_scene: Number of images to save per scene (0 to disable).
        stats_file: Path to save detection statistics (optional).
        luma_only: Only use brightness for content detection.
        adaptive_window: Window size for adaptive detection.
        fade_bias: Bias for fade detection (-1.0 to 1.0).
        downscale_factor: Factor to downscale frames by during detection.
        frame_skip: Number of frames to skip during detection, i.e. process only
            one in every N+1 frames (1/(N+1) of the video), speeding up detection
            at the expense of accuracy. Must be 0 (the default) when using a StatsManager.
        duration: How much of the video to process from the start position.
            Can be specified as frames (123), seconds (123s/123.45s),
            or timecode (HH:MM:SS[.nnn]).

    Returns:
        List of detected scenes as (start, end) FrameTimecode pairs.
    """
    # Create video stream
    video = open_video(video_path, backend="opencv")

    # Parse duration if specified
    duration_tc = parse_timecode(video, duration)

    # Parse filter_shorter_than if specified
    filter_shorter_than_tc = parse_timecode(video, filter_shorter_than)

    # Initialize scene manager with optional stats manager
    stats_manager = StatsManager() if stats_file else None
    scene_manager = SceneManager(stats_manager)

    # Configure scene manager
    if downscale_factor:
        scene_manager.auto_downscale = False
        scene_manager.downscale = downscale_factor

    # Create and add detector
    detector = create_detector(
        detector_type=detector_type,
        threshold=threshold,
        min_scene_len=min_scene_len,
        luma_only=luma_only,
        adaptive_window=adaptive_window,
        fade_bias=fade_bias,
    )
    scene_manager.add_detector(detector)

    # Detect scenes
    typer.echo("Detecting scenes...")
    scene_manager.detect_scenes(
        video=video,
        show_progress=True,
        frame_skip=frame_skip,
        duration=duration_tc,
    )

    # Get scene list
    scenes = scene_manager.get_scene_list()

    # Filter out scenes that are too short if filter_shorter_than is specified
    if filter_shorter_than_tc:
        original_count = len(scenes)
        scenes = [
            (start, end)
            for start, end in scenes
            if (end.get_frames() - start.get_frames()) >= filter_shorter_than_tc.get_frames()
        ]
        if len(scenes) < original_count:
            typer.echo(
                f"Filtered out {original_count - len(scenes)} scenes shorter "
                f"than {filter_shorter_than_tc.get_seconds():.1f} seconds "
                f"({filter_shorter_than_tc.get_frames()} frames)",
            )

    # Apply max scenes limit if specified
    if max_scenes and len(scenes) > max_scenes:
        typer.echo(f"Dropping last {len(scenes) - max_scenes} scenes to meet max_scenes ({max_scenes}) limit")
        scenes = scenes[:max_scenes]

    # Print scene information
    typer.echo(f"Found {len(scenes)} scenes:")
    for i, (start, end) in enumerate(scenes, 1):
        typer.echo(
            f"Scene {i}: {start.get_timecode()} to {end.get_timecode()} "
            f"({end.get_frames() - start.get_frames()} frames)",
        )

    # Save stats if requested
    if stats_file:
        typer.echo(f"Saving detection stats to {stats_file}")
        stats_manager.save_to_csv(stats_file)

    # Split video into scenes
    typer.echo("Splitting video into scenes...")
    try:
        split_video_ffmpeg(
            input_video_path=video_path,
            scene_list=scenes,
            output_dir=output_dir,
            show_progress=True,
        )
        typer.echo(f"Scenes have been saved to: {output_dir}")
    except Exception as e:
        raise typer.BadParameter(f"Error splitting video: {e}") from e

    # Save preview images if requested
    if save_images_per_scene > 0:
        typer.echo(f"Saving {save_images_per_scene} preview images per scene...")
        image_filenames = save_scene_images(
            scene_list=scenes,
            video=video,
            num_images=save_images_per_scene,
            output_dir=str(output_dir),
            show_progress=True,
        )

        # Generate HTML report with scene information and previews
        html_path = output_dir / "scene_report.html"
        write_scene_list_html(
            output_html_filename=str(html_path),
            scene_list=scenes,
            image_filenames=image_filenames,
        )
        typer.echo(f"Scene report saved to: {html_path}")

    return scenes

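
def _example_split_call() -> None:
    """Illustrative invocation with hypothetical paths: content detection, keep
    at most 20 scenes of at least 2 seconds, sampling every 3rd frame for speed."""
    scenes = detect_and_split_scenes(
        video_path="input.mp4",
        output_dir=Path("scenes/"),
        detector_type=DetectorType.CONTENT,
        min_scene_len=15,
        max_scenes=20,
        filter_shorter_than="2s",
        frame_skip=2,
    )
    typer.echo(f"{len(scenes)} scenes kept")
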
@app.command()
def main(  # noqa: PLR0913
    video_path: Path = typer.Argument(  # noqa: B008
        ...,
        help="Path to the input video file",
        exists=True,
        dir_okay=False,
    ),
    output_dir: str = typer.Argument(
        ...,
        help="Directory where split scenes will be saved",
    ),
    detector: DetectorType = typer.Option(  # noqa: B008
        DetectorType.CONTENT,
        help="Scene detection algorithm to use",
    ),
    threshold: Optional[float] = typer.Option(
        None,
        help="Detection threshold (meaning varies by detector)",
    ),
    max_scenes: Optional[int] = typer.Option(
        None,
        help="Maximum number of scenes to produce",
    ),
    min_scene_length: Optional[int] = typer.Option(
        None,
        help="Minimum scene length during detection. Forces the detector to make scenes at least this many frames. "
        "This affects scene detection behavior but does not filter out short scenes.",
    ),
    filter_shorter_than: Optional[str] = typer.Option(
        None,
        help="Filter out scenes shorter than this duration. Can be specified as frames (123), "
        "seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn]). These scenes will be detected but not saved.",
    ),
    skip_start: Optional[int] = typer.Option(
        None,
        help="Number of frames to skip at the start of the video",
    ),
    skip_end: Optional[int] = typer.Option(
        None,
        help="Number of frames to skip at the end of the video",
    ),
    duration: Optional[str] = typer.Option(
        None,
        "-d",
        help="How much of the video to process. Can be specified as frames (123), "
        "seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn])",
    ),
    save_images: int = typer.Option(
        0,
        help="Number of preview images to save per scene (0 to disable)",
    ),
    stats_file: Optional[str] = typer.Option(
        None,
        help="Path to save detection statistics CSV",
    ),
    luma_only: bool = typer.Option(
        False,
        help="Only use brightness for content detection",
    ),
    adaptive_window: Optional[int] = typer.Option(
        None,
        help="Window size for adaptive detection",
    ),
    fade_bias: Optional[float] = typer.Option(
        None,
        help="Bias for fade detection (-1.0 to 1.0)",
    ),
    downscale: Optional[int] = typer.Option(
        None,
        help="Factor to downscale frames by during detection",
    ),
    frame_skip: int = typer.Option(
        0,
        help="Number of frames to skip during processing",
    ),
) -> None:
    """Split video into scenes using PySceneDetect."""
    if skip_start or skip_end:
        typer.echo("Skipping start and end frames is not supported yet.")
        return

    # Validate output directory
    output_path = validate_output_dir(output_dir)

    # Detect and split scenes
    detect_and_split_scenes(
        video_path=str(video_path),
        output_dir=output_path,
        detector_type=detector,
        threshold=threshold,
        min_scene_len=min_scene_length,
        max_scenes=max_scenes,
        filter_shorter_than=filter_shorter_than,
        skip_start=skip_start,
        skip_end=skip_end,
        duration=duration,
        save_images_per_scene=save_images,
        stats_file=stats_file,
        luma_only=luma_only,
        adaptive_window=adaptive_window,
        fade_bias=fade_bias,
        downscale_factor=downscale,
        frame_skip=frame_skip,
    )


if __name__ == "__main__":
    app()
packages/ltx-trainer/scripts/train.py
ADDED
@@ -0,0 +1,64 @@
#!/usr/bin/env python

"""
Train LTXV models using configuration from YAML files.

This script provides a command-line interface for training LTXV models using
either LoRA fine-tuning or full model fine-tuning. It loads configuration from
a YAML file and passes it to the trainer.

Basic usage:
    python scripts/train.py CONFIG_PATH [--disable-progress-bars]

For multi-GPU/FSDP training, configure and launch via Accelerate:
    accelerate config
    accelerate launch scripts/train.py CONFIG_PATH
"""

from pathlib import Path

import typer
import yaml
from rich.console import Console

from ltx_trainer.config import LtxTrainerConfig
from ltx_trainer.trainer import LtxvTrainer

console = Console()
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Train LTXV models using configuration from YAML files.",
)


@app.command()
def main(
    config_path: str = typer.Argument(..., help="Path to YAML configuration file"),
    disable_progress_bars: bool = typer.Option(
        False,
        "--disable-progress-bars",
        help="Disable progress bars (useful for multi-process runs)",
    ),
) -> None:
    """Train the model using the provided configuration file."""
    # Load the configuration from the YAML file
    config_path = Path(config_path)
    if not config_path.exists():
        typer.echo(f"Error: Configuration file {config_path} does not exist.")
        raise typer.Exit(code=1)

    with open(config_path, "r") as file:
        config_data = yaml.safe_load(file)

    # Convert the loaded data into an LtxTrainerConfig object
    try:
        trainer_config = LtxTrainerConfig(**config_data)
    except Exception as e:
        typer.echo(f"Error: Invalid configuration data: {e}")
        raise typer.Exit(code=1) from e

    # Initialize the trainer and start training
    trainer = LtxvTrainer(trainer_config)
    trainer.train(disable_progress_bars=disable_progress_bars)


if __name__ == "__main__":
    app()
packages/ltx-trainer/src/ltx_trainer/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.54 kB).

packages/ltx-trainer/src/ltx_trainer/__pycache__/model_loader.cpython-312.pyc
ADDED
Binary file (13.9 kB).
packages/ltx-trainer/src/ltx_trainer/captioning.py
ADDED
@@ -0,0 +1,401 @@
"""
|
| 2 |
+
Audio-visual media captioning using multimodal models.
|
| 3 |
+
This module provides captioning capabilities for videos with audio using:
|
| 4 |
+
- Qwen2.5-Omni: Local model supporting text, audio, image, and video inputs (default)
|
| 5 |
+
- Gemini Flash: Cloud-based API for audio-visual captioning
|
| 6 |
+
Requirements:
|
| 7 |
+
- Qwen2.5-Omni: transformers>=4.50, torch
|
| 8 |
+
- Gemini Flash: google-generativeai (uv pip install google-generativeai)
|
| 9 |
+
Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import itertools
|
| 13 |
+
import re
|
| 14 |
+
from abc import ABC, abstractmethod
|
| 15 |
+
from enum import Enum
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
# Instruction for audio-visual captioning (default) - includes speech transcription and sounds
|
| 21 |
+
DEFAULT_CAPTION_INSTRUCTION = """\
|
| 22 |
+
Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections:
|
| 23 |
+
|
| 24 |
+
[VISUAL]: <Detailed description of people, objects, actions, settings, colors, and movements>
|
| 25 |
+
[SPEECH]: <Word-for-word transcription of everything spoken.
|
| 26 |
+
Listen carefully and transcribe the exact words. If no speech, write "None">
|
| 27 |
+
[SOUNDS]: <Description of music, ambient sounds, sound effects. If none, write "None">
|
| 28 |
+
[TEXT]: <Any on-screen text visible. If none, write "None">
|
| 29 |
+
|
| 30 |
+
You MUST fill in all four sections. For [SPEECH], transcribe the actual words spoken, not a summary."""
|
| 31 |
+
|
| 32 |
+
# Instruction for video-only captioning (no audio processing)
|
| 33 |
+
VIDEO_ONLY_CAPTION_INSTRUCTION = """\
|
| 34 |
+
Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections:
|
| 35 |
+
|
| 36 |
+
[VISUAL]: <Detailed description of people, objects, actions, settings, colors, and movements>
|
| 37 |
+
[TEXT]: <Any on-screen text visible. If none, write "None">
|
| 38 |
+
|
| 39 |
+
You MUST fill in both sections."""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CaptionerType(str, Enum):
|
| 43 |
+
"""Enum for different types of media captioners."""
|
| 44 |
+
|
| 45 |
+
QWEN_OMNI = "qwen_omni" # Local Qwen2.5-Omni model (audio + video)
|
| 46 |
+
GEMINI_FLASH = "gemini_flash" # Gemini Flash API (audio + video)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def create_captioner(captioner_type: CaptionerType, **kwargs) -> "MediaCaptioningModel":
|
| 50 |
+
"""Factory function to create a media captioner.
|
| 51 |
+
Args:
|
| 52 |
+
captioner_type: The type of captioner to create
|
| 53 |
+
**kwargs: Additional arguments to pass to the captioner constructor
|
| 54 |
+
Returns:
|
| 55 |
+
An instance of a MediaCaptioningModel
|
| 56 |
+
"""
|
| 57 |
+
match captioner_type:
|
| 58 |
+
case CaptionerType.QWEN_OMNI:
|
| 59 |
+
return QwenOmniCaptioner(**kwargs)
|
| 60 |
+
case CaptionerType.GEMINI_FLASH:
|
| 61 |
+
return GeminiFlashCaptioner(**kwargs)
|
| 62 |
+
case _:
|
| 63 |
+
raise ValueError(f"Unsupported captioner type: {captioner_type}")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
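
def _example_factory() -> None:
    """Illustrative only; 'clip.mp4' is a hypothetical path."""
    local = create_captioner(CaptionerType.QWEN_OMNI, use_8bit=True)  # kwargs go to __init__
    caption = local.caption("clip.mp4", fps=1, include_audio=True)
    cloud = create_captioner(CaptionerType.GEMINI_FLASH)  # key read from GEMINI_API_KEY/GOOGLE_API_KEY
    print(caption, cloud.supports_audio)  # noqa: T201
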
class MediaCaptioningModel(ABC):
    """Abstract base class for audio-visual media captioning models."""

    @abstractmethod
    def caption(self, path: str | Path, **kwargs) -> str:
        """Generate a caption for the given video or image.

        Args:
            path: Path to the video/image file to caption

        Returns:
            A string containing the generated caption
        """

    @property
    @abstractmethod
    def supports_audio(self) -> bool:
        """Whether this captioner supports audio input."""

    @staticmethod
    def _is_image_file(path: str | Path) -> bool:
        """Check if the file is an image based on its extension."""
        return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".heic", ".heif", ".webp"))

    @staticmethod
    def _is_video_file(path: str | Path) -> bool:
        """Check if the file is a video based on its extension."""
        return str(path).lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))

    @staticmethod
    def _clean_raw_caption(caption: str) -> str:
        """Clean up the raw caption by removing common VLM lead-in patterns."""
        start = ["The", "This"]
        kind = ["video", "image", "scene", "animated sequence", "clip", "footage"]
        act = ["displays", "shows", "features", "depicts", "presents", "showcases", "captures", "contains"]

        for x, y, z in itertools.product(start, kind, act):
            caption = caption.replace(f"{x} {y} {z} ", "", 1)

        return caption

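
def _example_clean_caption() -> None:
    """Shows what _clean_raw_caption strips: a leading '<The|This> <video|...> <shows|...> ' prefix, once."""
    raw = "The video shows a man playing guitar on a rooftop at sunset."
    assert MediaCaptioningModel._clean_raw_caption(raw) == "a man playing guitar on a rooftop at sunset."
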
class QwenOmniCaptioner(MediaCaptioningModel):
    """Audio-visual captioning using Alibaba's Qwen2.5-Omni model.

    Qwen2.5-Omni is an end-to-end multimodal model that can perceive text, images, audio, and video.
    It uses a Thinker-Talker architecture where the Thinker generates text and the Talker can
    generate speech. For captioning, we use only the Thinker component for text generation.

    Key features:
    - Block-wise processing for streaming multimodal inputs
    - TMRoPE (Time-aligned Multimodal RoPE) for synchronizing video and audio timestamps
    - Can extract and process audio directly from video files

    See: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni

    Model: Qwen/Qwen2.5-Omni-7B (7B parameters)
    """

    MODEL_ID = "Qwen/Qwen2.5-Omni-7B"

    # Default system prompt required by Qwen2.5-Omni for proper audio processing
    DEFAULT_SYSTEM_PROMPT = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
        "capable of perceiving auditory and visual inputs, as well as generating text and speech."
    )

    def __init__(
        self,
        device: str | torch.device | None = None,
        use_8bit: bool = False,
        instruction: str | None = None,
    ):
        """Initialize the Qwen2.5-Omni captioner.

        Args:
            device: Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu')
            use_8bit: Whether to use 8-bit quantization for reduced memory usage
            instruction: Custom instruction prompt. If None, uses the default instruction
        """
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.instruction = instruction
        self._load_model(use_8bit=use_8bit)

    @property
    def supports_audio(self) -> bool:
        return True

    def caption(
        self,
        path: str | Path,
        fps: int = 1,
        include_audio: bool = True,
        clean_caption: bool = True,
    ) -> str:
        """Generate a caption for the given video or image.

        Args:
            path: Path to the video/image file to caption
            fps: Frames per second to sample from videos
            include_audio: Whether to include audio in the captioning (for videos)
            clean_caption: Whether to clean up the raw caption by removing common VLM patterns

        Returns:
            A string containing the generated caption
        """
        path = Path(path)
        is_image = self._is_image_file(path)
        is_video = self._is_video_file(path)

        # Determine if we should process audio
        use_audio = include_audio and is_video

        # Use custom instruction if provided, otherwise pick the appropriate default
        if self.instruction is not None:
            instruction = self.instruction
        else:
            instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION

        # Build the user content based on media type.
        # Based on HuggingFace docs: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni
        user_content = []

        if is_image:
            user_content.append({"type": "image", "image": str(path)})
        elif is_video:
            user_content.append({"type": "video", "video": str(path)})

        # Add the instruction text
        user_content.append({"type": "text", "text": instruction})

        # Build conversation - use the default system prompt required by Qwen2.5-Omni.
        # Using a custom system prompt causes warnings and may affect audio processing.
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.DEFAULT_SYSTEM_PROMPT}],
            },
            {"role": "user", "content": user_content},
        ]

        # Process inputs using the processor's apply_chat_template.
        # For videos with audio, use load_audio_from_video=True and use_audio_in_video=True.
        inputs = self.processor.apply_chat_template(
            messages,
            load_audio_from_video=use_audio,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            fps=fps,
            padding=True,
            use_audio_in_video=use_audio,
        ).to(self.model.device)

        # Generate caption (text only, using the Thinker-only model).
        # Note: For Qwen2_5OmniThinkerForConditionalGeneration, use standard generate params
        # (not thinker_-prefixed ones; those are for the full Qwen2_5OmniForConditionalGeneration).
        input_len = inputs["input_ids"].shape[1]

        output_tokens = self.model.generate(
            **inputs,
            use_audio_in_video=use_audio,
            do_sample=False,
            max_new_tokens=1024,
        )

        # Extract only the generated tokens (exclude the input/prompt tokens)
        generated_tokens = output_tokens[:, input_len:]

        # Decode only the generated response
        caption_raw = self.processor.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Remove hallucinated conversation turns (e.g., "Human\nHuman\n..." or "Human: ...").
        # This is a known issue with chat models continuing to generate fake turns.
        # We look for patterns that are clearly hallucinated chat turns, not legitimate uses of "human":
        # match "\nHuman" followed by ":", "\n", or end of string (chat turn patterns).
        # This won't match "A human walks..." or "...the human body...".
        caption_raw = re.split(r"\nHuman(?::|(?:\s*\n)|$)", caption_raw, maxsplit=1)[0]
        caption_raw = caption_raw.strip()

        # Clean up caption if requested
        return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw

    def _load_model(self, use_8bit: bool) -> None:
        """Load the Qwen2.5-Omni model and processor.

        Uses the Thinker-only model (Qwen2_5OmniThinkerForConditionalGeneration) for text generation
        to save compute by not loading the audio generation components.
        """
        from transformers import (  # noqa: PLC0415
            BitsAndBytesConfig,
            Qwen2_5OmniProcessor,
            Qwen2_5OmniThinkerForConditionalGeneration,
        )

        quantization_config = BitsAndBytesConfig(load_in_8bit=True) if use_8bit else None

        # Use the Thinker-only model for text generation (saves memory by not loading the Talker)
        self.model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config,
            device_map="auto",
        )

        self.processor = Qwen2_5OmniProcessor.from_pretrained(self.MODEL_ID)

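
def _example_strip_hallucinated_turns() -> None:
    """Illustrative input for the regex used in QwenOmniCaptioner.caption above."""
    raw = "[VISUAL]: A dog runs.\n[SPEECH]: None\nHuman: describe more"
    clean = re.split(r"\nHuman(?::|(?:\s*\n)|$)", raw, maxsplit=1)[0]
    assert clean == "[VISUAL]: A dog runs.\n[SPEECH]: None"
    # "A human walks by" would be untouched: the pattern requires a newline
    # directly before "Human".
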
class GeminiFlashCaptioner(MediaCaptioningModel):
    """Audio-visual captioning using Google's Gemini Flash API.

    Gemini Flash is a cloud-based multimodal model that natively supports
    audio and video understanding. Requires a Google API key.

    Note: This captioner requires the `google-generativeai` package and a valid API key.
    Set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable, or pass the key directly.
    """

    MODEL_ID = "gemini-flash-lite-latest"

    def __init__(
        self,
        api_key: str | None = None,
        instruction: str | None = None,
    ):
        """Initialize the Gemini Flash captioner.

        Args:
            api_key: Google API key. If not provided, will look for the
                GEMINI_API_KEY or GOOGLE_API_KEY environment variable.
            instruction: Custom instruction prompt. If None, uses the default instruction
        """
        self.instruction = instruction
        self._init_client(api_key)

    @property
    def supports_audio(self) -> bool:
        return True

    def caption(
        self,
        path: str | Path,
        fps: int = 3,  # noqa: ARG002 - kept for API compatibility
        include_audio: bool = True,
        clean_caption: bool = True,
    ) -> str:
        """Generate a caption for the given video or image.

        Args:
            path: Path to the video/image file to caption
            fps: Frames per second (not used for Gemini, kept for API compatibility)
            include_audio: Whether to include audio content in the caption
            clean_caption: Whether to clean up the raw caption

        Returns:
            A string containing the generated caption
        """
        import time  # noqa: PLC0415

        path = Path(path)
        is_video = self._is_video_file(path)
        use_audio = include_audio and is_video

        # Use custom instruction if provided, otherwise pick the appropriate default
        if self.instruction is not None:
            instruction = self.instruction
        else:
            instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION

        # Upload the file to Gemini
        uploaded_file = self._genai.upload_file(path)

        # Wait for processing to complete (videos need time to process)
        while uploaded_file.state.name == "PROCESSING":
            time.sleep(1)
            uploaded_file = self._genai.get_file(uploaded_file.name)

        if uploaded_file.state.name == "FAILED":
            raise RuntimeError(f"Gemini file processing failed for: {path}")

        # Generate caption
        response = self._model.generate_content([uploaded_file, instruction])

        caption_raw = response.text

        # Clean up the uploaded file
        self._genai.delete_file(uploaded_file.name)

        # Clean up caption if requested
        return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw

    def _init_client(self, api_key: str | None) -> None:
        """Initialize the Gemini API client."""
        import os  # noqa: PLC0415

        try:
            import google.generativeai as genai  # noqa: PLC0415
        except ImportError as e:
            raise ImportError(
                "The `google-generativeai` package is required for Gemini Flash captioning. "
                "Install it with: `uv pip install google-generativeai`"
            ) from e

        # Get API key from argument or environment.
        # GEMINI_API_KEY is the recommended variable; GOOGLE_API_KEY also works.
        resolved_api_key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")

        if not resolved_api_key:
            raise ValueError(
                "Gemini API key is required. Provide it via the `api_key` argument "
                "or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable."
            )

        # Configure the genai library with the API key
        genai.configure(api_key=resolved_api_key)

        # Store a reference to the genai module for file operations
        self._genai = genai

        # Initialize the model
        self._model = genai.GenerativeModel(self.MODEL_ID)


def example() -> None:
    """Example usage of the captioning module."""
    import sys  # noqa: PLC0415

    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <video_path> [captioner_type]")  # noqa: T201
        print("  captioner_type: qwen_omni (default) or gemini_flash")  # noqa: T201
        sys.exit(1)

    video_path = sys.argv[1]
    captioner_type = CaptionerType(sys.argv[2]) if len(sys.argv) > 2 else CaptionerType.QWEN_OMNI

    print(f"Using {captioner_type.value} captioner:")  # noqa: T201
    captioner = create_captioner(captioner_type)
    caption = captioner.caption(video_path)
    print(f"CAPTION: {caption}")  # noqa: T201


if __name__ == "__main__":
    example()
packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py
ADDED
@@ -0,0 +1,85 @@
# ruff: noqa: PLC0415
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
8-bit Gemma text encoder loading utilities.
|
| 5 |
+
This module provides functionality for loading the Gemma text encoder in 8-bit precision
|
| 6 |
+
using bitsandbytes, which significantly reduces GPU memory usage.
|
| 7 |
+
Example usage:
|
| 8 |
+
from ltx_trainer.gemma_8bit import load_8bit_gemma
|
| 9 |
+
text_encoder = load_8bit_gemma(gemma_model_path="/path/to/gemma")
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
from collections.abc import Generator
|
| 16 |
+
from contextlib import contextmanager
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoder
|
| 22 |
+
from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_8bit_gemma(gemma_model_path: str | Path, dtype: torch.dtype = torch.bfloat16) -> GemmaTextEncoder:
|
| 26 |
+
"""Load the Gemma text encoder in 8-bit precision using bitsandbytes.
|
| 27 |
+
Only the Gemma LLM backbone is loaded here. The embeddings processor
|
| 28 |
+
(feature extractor + connectors) should be loaded separately via
|
| 29 |
+
:func:`ltx_trainer.model_loader.load_embeddings_processor`.
|
| 30 |
+
Args:
|
| 31 |
+
gemma_model_path: Path to Gemma model directory
|
| 32 |
+
dtype: Data type for non-quantized model weights
|
| 33 |
+
Returns:
|
| 34 |
+
GemmaTextEncoder with 8-bit quantized Gemma backbone
|
| 35 |
+
Raises:
|
| 36 |
+
ImportError: If bitsandbytes is not installed
|
| 37 |
+
FileNotFoundError: If required model files are not found
|
| 38 |
+
"""
|
| 39 |
+
try:
|
| 40 |
+
from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration
|
| 41 |
+
except ImportError as e:
|
| 42 |
+
raise ImportError(
|
| 43 |
+
"8-bit text encoder loading requires bitsandbytes. Install it with: uv pip install bitsandbytes"
|
| 44 |
+
) from e
|
| 45 |
+
|
| 46 |
+
gemma_path = _find_gemma_subpath(gemma_model_path, "model*.safetensors")
|
| 47 |
+
tokenizer_path = _find_gemma_subpath(gemma_model_path, "tokenizer.model")
|
| 48 |
+
|
| 49 |
+
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 50 |
+
with _suppress_accelerate_memory_warnings():
|
| 51 |
+
gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
|
| 52 |
+
gemma_path,
|
| 53 |
+
quantization_config=quantization_config,
|
| 54 |
+
torch_dtype=torch.bfloat16,
|
| 55 |
+
device_map="auto",
|
| 56 |
+
local_files_only=True,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
tokenizer = LTXVGemmaTokenizer(tokenizer_path, 1024)
|
| 60 |
+
|
| 61 |
+
return GemmaTextEncoder(
|
| 62 |
+
tokenizer=tokenizer,
|
| 63 |
+
model=gemma_model,
|
| 64 |
+
dtype=dtype,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _find_gemma_subpath(root_path: str | Path, pattern: str) -> str:
|
| 69 |
+
"""Find a file matching a glob pattern and return its parent directory."""
|
| 70 |
+
matches = list(Path(root_path).rglob(pattern))
|
| 71 |
+
if not matches:
|
| 72 |
+
raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}")
|
| 73 |
+
return str(matches[0].parent)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@contextmanager
|
| 77 |
+
def _suppress_accelerate_memory_warnings() -> Generator[None, None, None]:
|
| 78 |
+
"""Temporarily suppress INFO warnings from accelerate about memory allocation."""
|
| 79 |
+
accelerate_logger = logging.getLogger("accelerate.utils.modeling")
|
| 80 |
+
old_level = accelerate_logger.level
|
| 81 |
+
accelerate_logger.setLevel(logging.WARNING)
|
| 82 |
+
try:
|
| 83 |
+
yield
|
| 84 |
+
finally:
|
| 85 |
+
accelerate_logger.setLevel(old_level)
|
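A minimal usage sketch of the loader above, assuming a local Gemma checkpoint directory (the path is a placeholder); the `encode()` call mirrors how `trainer.py` later in this diff consumes the encoder when caching validation prompt embeddings:

# Usage sketch (the path is a placeholder, not part of this diff):
from ltx_trainer.gemma_8bit import load_8bit_gemma

text_encoder = load_8bit_gemma(gemma_model_path="/path/to/gemma")

# encode() returns hidden states and an attention mask, consumed the same way
# trainer.py does in _load_text_encoder_and_cache_embeddings().
hidden_states, attention_mask = text_encoder.encode("A drone shot over a foggy coastline")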
packages/ltx-trainer/src/ltx_trainer/gpu_utils.py
ADDED
@@ -0,0 +1,90 @@
"""GPU memory management utilities for training and inference."""

import functools
import gc
import subprocess
from typing import Callable, TypeVar

import torch

from ltx_trainer import logger

F = TypeVar("F", bound=Callable)


def free_gpu_memory(log: bool = False) -> None:
    """Free GPU memory by running garbage collection and emptying CUDA cache.
    Args:
        log: If True, log memory stats after clearing
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if log:
            allocated = torch.cuda.memory_allocated() / 1024**3
            reserved = torch.cuda.memory_reserved() / 1024**3
            logger.debug(f"GPU memory freed. Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")


class free_gpu_memory_context:  # noqa: N801
    """Context manager and decorator to free GPU memory before and/or after execution.
    Can be used as a decorator:
        @free_gpu_memory_context(after=True)
        def my_function():
            ...
    Or as a context manager:
        with free_gpu_memory_context():
            heavy_operation()
    Args:
        before: Free memory before execution (default: False)
        after: Free memory after execution (default: True)
        log: Log memory stats when freeing (default: False)
    """

    def __init__(self, *, before: bool = False, after: bool = True, log: bool = False) -> None:
        self.before = before
        self.after = after
        self.log = log

    def __enter__(self) -> "free_gpu_memory_context":
        if self.before:
            free_gpu_memory(log=self.log)
        return self

    def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: object) -> None:
        if self.after:
            free_gpu_memory(log=self.log)

    def __call__(self, func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> object:
            with self:
                return func(*args, **kwargs)

        return wrapper  # type: ignore


def get_gpu_memory_gb(device: torch.device) -> float:
    """Get current GPU memory usage in GB using nvidia-smi.
    Args:
        device: torch.device to get memory usage for
    Returns:
        Current GPU memory usage in GB
    """
    try:
        device_id = device.index if device.index is not None else 0
        result = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=memory.used",
                "--format=csv,nounits,noheader",
                "-i",
                str(device_id),
            ],
            encoding="utf-8",
        )
        return float(result.strip()) / 1024  # Convert MB to GB
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
        logger.error(f"Failed to get GPU memory from nvidia-smi: {e}")
        # Fallback to torch
        return torch.cuda.memory_allocated(device) / 1024**3
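For reference, a short sketch showing both ways `free_gpu_memory_context` can be applied, plus a memory readout via `get_gpu_memory_gb`; the heavy operation is an illustrative stand-in and assumes a CUDA device is present:

import torch

from ltx_trainer.gpu_utils import free_gpu_memory_context, get_gpu_memory_gb


@free_gpu_memory_context(log=True)  # frees VRAM after the call returns (after=True is the default)
def run_heavy_operation() -> None:
    # Illustrative stand-in for a memory-hungry workload
    _ = torch.randn(1024, 1024, device="cuda") @ torch.randn(1024, 1024, device="cuda")


# The same class used as a context manager, freeing memory both before and after the block
with free_gpu_memory_context(before=True, after=True):
    run_heavy_operation()

print(f"GPU memory in use: {get_gpu_memory_gb(torch.device('cuda:0')):.2f} GB")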
packages/ltx-trainer/src/ltx_trainer/progress.py
ADDED
@@ -0,0 +1,236 @@
"""Progress tracking for LTX training.
This module provides a unified progress display for training and validation sampling,
encapsulating all Rich progress bar logic in one place.
"""

from rich.progress import (
    BarColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)


class SamplingContext:
    """Context for validation sampling progress tracking.
    Provides a unified progress display showing current video and denoising step.
    Display format: "Sampling X/Y [████████████] step Z/W"
    The progress bar shows the denoising progress for the current video.
    """

    def __init__(self, progress: Progress | None, task: TaskID | None, num_prompts: int, num_steps: int):
        self._progress = progress
        self._task = task
        self._num_prompts = num_prompts
        self._num_steps = num_steps

    def start_video(self, video_idx: int) -> None:
        """Start tracking a new video (resets step progress)."""
        if self._progress is None or self._task is None:
            return
        # Reset task for new video: completed=0, total=num_steps
        self._progress.reset(self._task, total=self._num_steps)
        self._progress.update(
            self._task,
            completed=0,
            video=f"{video_idx + 1}/{self._num_prompts}",
            info=f"step 0/{self._num_steps}",
        )

    def advance_step(self) -> None:
        """Advance the denoising step by one."""
        if self._progress is None or self._task is None:
            return
        self._progress.advance(self._task)
        completed = int(self._progress.tasks[self._task].completed)
        self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")

    def cleanup(self) -> None:
        """Hide sampling task when done."""
        if self._progress is None or self._task is None:
            return
        self._progress.update(self._task, visible=False)


class StandaloneSamplingProgress:
    """Standalone progress display for inference scripts.
    Unlike SamplingContext (which integrates with TrainingProgress), this class
    manages its own Rich Progress instance for use in standalone inference scripts.
    Usage:
        with StandaloneSamplingProgress(num_steps=30) as ctx:
            for step in range(30):
                # ... denoising step ...
                ctx.advance_step()
    """

    def __init__(self, num_steps: int, description: str = "Generating"):
        """Initialize standalone sampling progress.
        Args:
            num_steps: Total number of denoising steps
            description: Description to show in progress bar
        """
        self._num_steps = num_steps
        self._description = description
        self._progress: Progress | None = None
        self._task: TaskID | None = None

    def __enter__(self) -> "StandaloneSamplingProgress":
        """Start the progress display."""
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        self._progress.__enter__()
        self._task = self._progress.add_task(
            self._description,
            total=self._num_steps,
            info=f"step 0/{self._num_steps}",
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the progress display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    def advance_step(self) -> None:
        """Advance the denoising step by one."""
        if self._progress is None or self._task is None:
            return
        self._progress.advance(self._task)
        completed = int(self._progress.tasks[self._task].completed)
        self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")


class TrainingProgress:
    """Manages Rich progress display for training and validation.
    This class encapsulates all progress bar logic, providing a clean interface
    for the trainer to update progress without dealing with Rich internals.
    Usage:
        with TrainingProgress(enabled=True, total_steps=1000) as progress:
            for step in range(1000):
                # ... training step ...
                progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)
                if should_validate:
                    sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
                    sampler = ValidationSampler(..., sampling_context=sampling_ctx)
                    for prompt_idx, prompt in enumerate(prompts):
                        sampling_ctx.start_video(prompt_idx)
                        sampler.generate(...)
                    sampling_ctx.cleanup()
    """

    def __init__(self, enabled: bool, total_steps: int):
        """Initialize progress tracking.
        Args:
            enabled: Whether to display progress bars (False for non-main processes)
            total_steps: Total number of training steps
        """
        self._enabled = enabled
        self._total_steps = total_steps
        self._train_task: TaskID | None = None

        if not enabled:
            self._progress = None
            return

        # Single Progress instance with flexible columns
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            TextColumn("{task.fields[video]}", style="magenta"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )

    def __enter__(self) -> "TrainingProgress":
        """Enter the progress context, starting the live display."""
        if self._progress is not None:
            self._progress.__enter__()
            self._train_task = self._progress.add_task(
                "Training",
                total=self._total_steps,
                video=f"0/{self._total_steps}",
                info="Starting...",
            )
        return self

    def __exit__(self, *args) -> None:
        """Exit the progress context, stopping the live display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    @property
    def enabled(self) -> bool:
        """Whether progress display is enabled."""
        return self._enabled

    def update_training(
        self,
        *,
        loss: float,
        lr: float,
        step_time: float,
        advance: bool = True,
    ) -> None:
        """Update the training progress display.
        Args:
            loss: Current training loss
            lr: Current learning rate
            step_time: Time taken for this step in seconds
            advance: Whether to advance the progress by one step
        """
        if self._progress is None or self._train_task is None:
            return

        info = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step"
        self._progress.update(
            self._train_task,
            advance=1 if advance else 0,
            info=info,
        )
        # Update step count in video column
        completed = int(self._progress.tasks[self._train_task].completed)
        self._progress.update(self._train_task, video=f"{completed}/{self._total_steps}")

    def start_sampling(self, num_prompts: int, num_steps: int) -> SamplingContext:
        """Start validation sampling progress tracking.
        Creates a task that shows current video and denoising step progress.
        Format: "Sampling X/Y [████████████] step Z/W"
        Args:
            num_prompts: Number of validation prompts to sample
            num_steps: Number of denoising steps per sample
        Returns:
            SamplingContext for tracking progress (no-op if progress is disabled)
        """
        if self._progress is None:
            # Return a no-op context when progress is disabled
            return SamplingContext(
                progress=None,
                task=None,
                num_prompts=num_prompts,
                num_steps=num_steps,
            )

        task = self._progress.add_task(
            "Sampling",
            total=num_steps,
            completed=0,
            video=f"0/{num_prompts}",
            info=f"step 0/{num_steps}",
        )

        return SamplingContext(
            progress=self._progress,
            task=task,
            num_prompts=num_prompts,
            num_steps=num_steps,
        )
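To make the interplay between `TrainingProgress` and `SamplingContext` concrete, here is a minimal mock loop exercising the same calls the trainer makes; the loss, learning-rate, and timing values are fabricated for illustration:

import time

from ltx_trainer.progress import TrainingProgress

with TrainingProgress(enabled=True, total_steps=10) as progress:
    for step in range(10):
        time.sleep(0.1)  # stand-in for a real training step
        progress.update_training(loss=1.0 / (step + 1), lr=1e-4, step_time=0.1)

    # Validation sampling: one task tracking 2 videos of 5 denoising steps each
    ctx = progress.start_sampling(num_prompts=2, num_steps=5)
    for video_idx in range(2):
        ctx.start_video(video_idx)
        for _ in range(5):
            ctx.advance_step()
    ctx.cleanup()

When `enabled=False` (as on non-main processes), every call above is a no-op, so the trainer needs no rank guards around progress updates.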
packages/ltx-trainer/src/ltx_trainer/quantization.py
ADDED
@@ -0,0 +1,195 @@
# Adapted from: https://github.com/bghira/SimpleTuner
# With improvements from: https://github.com/ostris/ai-toolkit
from typing import Literal

import torch
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn

from ltx_trainer import logger

QuantizationOptions = Literal[
    "int8-quanto",
    "int4-quanto",
    "int2-quanto",
    "fp8-quanto",
    "fp8uz-quanto",
]

# Modules to exclude from quantization.
# These are glob patterns passed to quanto's `exclude` parameter.
# When quantizing the full model at once, these patterns match against full module paths.
# When quantizing block-by-block, we also use SKIP_ROOT_MODULES for top-level modules.
EXCLUDE_PATTERNS = [
    # Input/output projection layers
    "patchify_proj",
    "audio_patchify_proj",
    "proj_out",
    "audio_proj_out",
    # Timestep embedding layers - int4 tinygemm requires strict bfloat16 input
    # and these receive float32 sinusoidal embeddings that are cast to bfloat16
    "*adaln*",
    "time_proj",
    "timestep_embedder*",
    # Caption/text projection layers
    "caption_projection*",
    "audio_caption_projection*",
    # Normalization layers (usually excluded from quantization)
    "*norm*",
]

# Top-level modules to skip entirely during block-by-block quantization.
# These are exact matches against model.named_children() names.
# (Needed because quanto's exclude patterns don't work when calling quantize() directly on a module)
SKIP_ROOT_MODULES = {
    "patchify_proj",
    "audio_patchify_proj",
    "proj_out",
    "audio_proj_out",
    "audio_caption_projection",
}


def quantize_model(
    model: torch.nn.Module,
    precision: QuantizationOptions,
    quantize_activations: bool = False,
    device: torch.device | str | None = None,
) -> torch.nn.Module:
    """
    Quantize a model using optimum-quanto.
    For large models with transformer_blocks, this function quantizes block-by-block
    on GPU then moves back to CPU, which is much faster than quantizing on CPU and
    uses less peak VRAM than loading the entire model to GPU at once.
    Args:
        model: The model to quantize.
        precision: The quantization precision (e.g. "int8-quanto", "fp8-quanto").
        quantize_activations: Whether to quantize activations in addition to weights.
        device: Device to use for quantization. If None, uses CUDA if available, else CPU.
    Returns:
        The quantized model.
    """
    from optimum.quanto import freeze, quantize  # noqa: PLC0415

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    weight_quant = _get_quanto_dtype(precision)

    if quantize_activations:
        logger.debug("Quantizing model weights and activations")
        activations_quant = weight_quant
    else:
        activations_quant = None

    # Remember original device to restore after quantization
    original_device = next(model.parameters()).device

    # Check if model has transformer_blocks for block-by-block quantization
    if hasattr(model, "transformer_blocks"):
        logger.debug("Quantizing model using block-by-block approach for memory efficiency")
        _quantize_blockwise(
            model,
            weight_quant=weight_quant,
            activations_quant=activations_quant,
            device=device,
        )
    else:
        # Fallback: quantize entire model at once
        model.to(device)
        quantize(model, weights=weight_quant, activations=activations_quant, exclude=EXCLUDE_PATTERNS)
        freeze(model)

    # Restore model to original device
    model.to(original_device)

    return model


def _quantize_blockwise(
    model: torch.nn.Module,
    weight_quant: torch.dtype,
    activations_quant: torch.dtype | None,
    device: torch.device,
) -> None:
    """Quantize a model block-by-block using optimum-quanto.
    This approach:
    1. Moves each transformer block to GPU
    2. Quantizes on GPU (fast!)
    3. Freezes the quantized weights
    4. Moves back to CPU
    This is much faster than quantizing on CPU and uses less peak VRAM
    than loading the entire model to GPU.
    """
    from optimum.quanto import freeze, quantize  # noqa: PLC0415

    original_dtype = next(model.parameters()).dtype
    transformer_blocks = list(model.transformer_blocks)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        transient=True,
    ) as progress:
        task = progress.add_task("Quantizing transformer blocks", total=len(transformer_blocks))

        for block in transformer_blocks:
            # Move block to GPU
            block.to(device, dtype=original_dtype, non_blocking=True)

            # Quantize on GPU
            quantize(block, weights=weight_quant, activations=activations_quant, exclude=EXCLUDE_PATTERNS)
            freeze(block)

            # Move back to CPU to free up VRAM for next block
            block.to("cpu", non_blocking=True)

            progress.advance(task)

    # Quantize remaining non-transformer-block modules (e.g., embeddings, timestep projections)
    # Skip modules that should not be quantized (patchify_proj, proj_out, etc.)
    logger.debug("Quantizing remaining model components")

    for name, module in model.named_children():
        if name == "transformer_blocks":
            continue  # Already quantized

        if name in SKIP_ROOT_MODULES:
            logger.debug(f"Skipping quantization for module: {name}")
            continue  # Don't quantize these modules

        # Move to device, quantize, freeze, move back
        module.to(device, dtype=original_dtype, non_blocking=True)
        quantize(module, weights=weight_quant, activations=activations_quant, exclude=EXCLUDE_PATTERNS)
        freeze(module)
        module.to("cpu", non_blocking=True)


def _get_quanto_dtype(precision: QuantizationOptions) -> torch.dtype:
    """Map precision string to quanto dtype."""
    from optimum.quanto import (  # noqa: PLC0415
        qfloat8,
        qfloat8_e4m3fnuz,
        qint2,
        qint4,
        qint8,
    )

    if precision == "int2-quanto":
        return qint2
    elif precision == "int4-quanto":
        return qint4
    elif precision == "int8-quanto":
        return qint8
    elif precision in ("fp8-quanto", "fp8uz-quanto"):
        if torch.backends.mps.is_available():
            raise ValueError("FP8 quantization is not supported on MPS devices. Use int2, int4, or int8 instead.")
        if precision == "fp8-quanto":
            return qfloat8
        elif precision == "fp8uz-quanto":
            return qfloat8_e4m3fnuz

    raise ValueError(f"Invalid quantization precision: {precision}")
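As a rough illustration of the two code paths in `quantize_model`, the sketch below quantizes a toy module; since it has no `transformer_blocks` attribute, the whole-model fallback is taken. The toy model is an assumption for the example (the real target is the LTX-2 transformer), and `optimum-quanto` must be installed:

import torch

from ltx_trainer.quantization import quantize_model

# Toy stand-in model; none of EXCLUDE_PATTERNS match its layer names,
# so both Linear layers get int8-quantized.
toy = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)

# Falls back to quantizing the whole module at once, then freezing the weights
toy = quantize_model(toy, precision="int8-quanto")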
packages/ltx-trainer/src/ltx_trainer/trainer.py
ADDED
@@ -0,0 +1,1000 @@
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import warnings
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import wandb
|
| 9 |
+
import yaml
|
| 10 |
+
from accelerate import Accelerator, DistributedType
|
| 11 |
+
from accelerate.utils import set_seed
|
| 12 |
+
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
|
| 13 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
| 14 |
+
from peft.utils import ModulesToSaveWrapper
|
| 15 |
+
from pydantic import BaseModel
|
| 16 |
+
from safetensors.torch import load_file, save_file
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from torch.optim import AdamW
|
| 19 |
+
from torch.optim.lr_scheduler import (
|
| 20 |
+
CosineAnnealingLR,
|
| 21 |
+
CosineAnnealingWarmRestarts,
|
| 22 |
+
LinearLR,
|
| 23 |
+
LRScheduler,
|
| 24 |
+
PolynomialLR,
|
| 25 |
+
StepLR,
|
| 26 |
+
)
|
| 27 |
+
from torch.utils.data import DataLoader
|
| 28 |
+
from torchvision.transforms import functional as F # noqa: N812
|
| 29 |
+
|
| 30 |
+
from ltx_core.text_encoders.gemma import convert_to_additive_mask
|
| 31 |
+
from ltx_trainer import logger
|
| 32 |
+
from ltx_trainer.config import LtxTrainerConfig
|
| 33 |
+
from ltx_trainer.config_display import print_config
|
| 34 |
+
from ltx_trainer.datasets import PrecomputedDataset
|
| 35 |
+
from ltx_trainer.gpu_utils import free_gpu_memory, free_gpu_memory_context, get_gpu_memory_gb
|
| 36 |
+
from ltx_trainer.hf_hub_utils import push_to_hub
|
| 37 |
+
from ltx_trainer.model_loader import load_embeddings_processor, load_text_encoder
|
| 38 |
+
from ltx_trainer.model_loader import load_model as load_ltx_model
|
| 39 |
+
from ltx_trainer.progress import TrainingProgress
|
| 40 |
+
from ltx_trainer.quantization import quantize_model
|
| 41 |
+
from ltx_trainer.timestep_samplers import SAMPLERS
|
| 42 |
+
from ltx_trainer.training_strategies import get_training_strategy
|
| 43 |
+
from ltx_trainer.utils import open_image_as_srgb, save_image
|
| 44 |
+
from ltx_trainer.validation_sampler import CachedPromptEmbeddings, GenerationConfig, ValidationSampler
|
| 45 |
+
from ltx_trainer.video_utils import read_video, save_video
|
| 46 |
+
|
| 47 |
+
# Disable irrelevant warnings from transformers
|
| 48 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
| 49 |
+
|
| 50 |
+
# Silence bitsandbytes warnings about casting
|
| 51 |
+
warnings.filterwarnings(
|
| 52 |
+
"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Disable progress bars if not main process
|
| 56 |
+
IS_MAIN_PROCESS = os.environ.get("LOCAL_RANK", "0") == "0"
|
| 57 |
+
if not IS_MAIN_PROCESS:
|
| 58 |
+
from transformers.utils.logging import disable_progress_bar
|
| 59 |
+
|
| 60 |
+
disable_progress_bar()
|
| 61 |
+
|
| 62 |
+
StepCallback = Callable[[int, int, list[Path]], None] # (step, total, list[sampled_video_path]) -> None
|
| 63 |
+
|
| 64 |
+
MEMORY_CHECK_INTERVAL = 200
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TrainingStats(BaseModel):
|
| 68 |
+
"""Statistics collected during training"""
|
| 69 |
+
|
| 70 |
+
total_time_seconds: float
|
| 71 |
+
steps_per_second: float
|
| 72 |
+
samples_per_second: float
|
| 73 |
+
peak_gpu_memory_gb: float
|
| 74 |
+
global_batch_size: int
|
| 75 |
+
num_processes: int
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class LtxvTrainer:
|
| 79 |
+
def __init__(self, trainer_config: LtxTrainerConfig) -> None:
|
| 80 |
+
self._config = trainer_config
|
| 81 |
+
if IS_MAIN_PROCESS:
|
| 82 |
+
print_config(trainer_config)
|
| 83 |
+
self._training_strategy = get_training_strategy(self._config.training_strategy)
|
| 84 |
+
self._cached_validation_embeddings = self._load_text_encoder_and_cache_embeddings()
|
| 85 |
+
self._load_models()
|
| 86 |
+
self._setup_accelerator()
|
| 87 |
+
self._collect_trainable_params()
|
| 88 |
+
self._load_checkpoint()
|
| 89 |
+
self._prepare_models_for_training()
|
| 90 |
+
self._dataset = None
|
| 91 |
+
self._global_step = -1
|
| 92 |
+
self._checkpoint_paths = []
|
| 93 |
+
self._init_wandb()
|
| 94 |
+
|
| 95 |
+
def train( # noqa: PLR0912, PLR0915
|
| 96 |
+
self,
|
| 97 |
+
disable_progress_bars: bool = False,
|
| 98 |
+
step_callback: StepCallback | None = None,
|
| 99 |
+
) -> tuple[Path, TrainingStats]:
|
| 100 |
+
"""
|
| 101 |
+
Start the training process.
|
| 102 |
+
Returns:
|
| 103 |
+
Tuple of (saved_model_path, training_stats)
|
| 104 |
+
"""
|
| 105 |
+
device = self._accelerator.device
|
| 106 |
+
cfg = self._config
|
| 107 |
+
start_mem = get_gpu_memory_gb(device)
|
| 108 |
+
|
| 109 |
+
train_start_time = time.time()
|
| 110 |
+
|
| 111 |
+
# Use the same seed for all processes and ensure deterministic operations
|
| 112 |
+
set_seed(cfg.seed)
|
| 113 |
+
logger.debug(f"Process {self._accelerator.process_index} using seed: {cfg.seed}")
|
| 114 |
+
|
| 115 |
+
self._init_optimizer()
|
| 116 |
+
self._init_dataloader()
|
| 117 |
+
data_iter = iter(self._dataloader)
|
| 118 |
+
self._init_timestep_sampler()
|
| 119 |
+
|
| 120 |
+
# Synchronize all processes after initialization
|
| 121 |
+
self._accelerator.wait_for_everyone()
|
| 122 |
+
|
| 123 |
+
Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
|
| 124 |
+
|
| 125 |
+
# Save the training configuration as YAML
|
| 126 |
+
self._save_config()
|
| 127 |
+
|
| 128 |
+
logger.info("🚀 Starting training...")
|
| 129 |
+
|
| 130 |
+
# Create progress tracking (disabled for non-main processes or when explicitly disabled)
|
| 131 |
+
progress_enabled = IS_MAIN_PROCESS and not disable_progress_bars
|
| 132 |
+
progress = TrainingProgress(
|
| 133 |
+
enabled=progress_enabled,
|
| 134 |
+
total_steps=cfg.optimization.steps,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
if IS_MAIN_PROCESS and disable_progress_bars:
|
| 138 |
+
logger.warning("Progress bars disabled. Intermediate status messages will be logged instead.")
|
| 139 |
+
|
| 140 |
+
self._transformer.train()
|
| 141 |
+
self._global_step = 0
|
| 142 |
+
|
| 143 |
+
peak_mem_during_training = start_mem
|
| 144 |
+
|
| 145 |
+
sampled_videos_paths = None
|
| 146 |
+
|
| 147 |
+
with progress:
|
| 148 |
+
# Initial validation before training starts
|
| 149 |
+
if cfg.validation.interval and not cfg.validation.skip_initial_validation:
|
| 150 |
+
sampled_videos_paths = self._sample_videos(progress)
|
| 151 |
+
if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
|
| 152 |
+
self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts)
|
| 153 |
+
|
| 154 |
+
self._accelerator.wait_for_everyone()
|
| 155 |
+
|
| 156 |
+
for step in range(cfg.optimization.steps * cfg.optimization.gradient_accumulation_steps):
|
| 157 |
+
# Get next batch, reset the dataloader if needed
|
| 158 |
+
try:
|
| 159 |
+
batch = next(data_iter)
|
| 160 |
+
except StopIteration:
|
| 161 |
+
data_iter = iter(self._dataloader)
|
| 162 |
+
batch = next(data_iter)
|
| 163 |
+
|
| 164 |
+
step_start_time = time.time()
|
| 165 |
+
with self._accelerator.accumulate(self._transformer):
|
| 166 |
+
is_optimization_step = (step + 1) % cfg.optimization.gradient_accumulation_steps == 0
|
| 167 |
+
if is_optimization_step:
|
| 168 |
+
self._global_step += 1
|
| 169 |
+
|
| 170 |
+
loss = self._training_step(batch)
|
| 171 |
+
self._accelerator.backward(loss)
|
| 172 |
+
|
| 173 |
+
if self._accelerator.sync_gradients and cfg.optimization.max_grad_norm > 0:
|
| 174 |
+
self._accelerator.clip_grad_norm_(
|
| 175 |
+
self._trainable_params,
|
| 176 |
+
cfg.optimization.max_grad_norm,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
self._optimizer.step()
|
| 180 |
+
self._optimizer.zero_grad()
|
| 181 |
+
|
| 182 |
+
if self._lr_scheduler is not None:
|
| 183 |
+
self._lr_scheduler.step()
|
| 184 |
+
|
| 185 |
+
# Run validation if needed
|
| 186 |
+
if (
|
| 187 |
+
cfg.validation.interval
|
| 188 |
+
and self._global_step > 0
|
| 189 |
+
and self._global_step % cfg.validation.interval == 0
|
| 190 |
+
and is_optimization_step
|
| 191 |
+
):
|
| 192 |
+
if self._accelerator.distributed_type == DistributedType.FSDP:
|
| 193 |
+
# FSDP: All processes must participate in validation
|
| 194 |
+
sampled_videos_paths = self._sample_videos(progress)
|
| 195 |
+
if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
|
| 196 |
+
self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts)
|
| 197 |
+
# DDP: Only main process runs validation
|
| 198 |
+
elif IS_MAIN_PROCESS:
|
| 199 |
+
sampled_videos_paths = self._sample_videos(progress)
|
| 200 |
+
if sampled_videos_paths and self._config.wandb.log_validation_videos:
|
| 201 |
+
self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts)
|
| 202 |
+
|
| 203 |
+
# Save checkpoint if needed
|
| 204 |
+
if (
|
| 205 |
+
cfg.checkpoints.interval
|
| 206 |
+
and self._global_step > 0
|
| 207 |
+
and self._global_step % cfg.checkpoints.interval == 0
|
| 208 |
+
and is_optimization_step
|
| 209 |
+
):
|
| 210 |
+
self._save_checkpoint()
|
| 211 |
+
|
| 212 |
+
self._accelerator.wait_for_everyone()
|
| 213 |
+
|
| 214 |
+
# Call step callback if provided
|
| 215 |
+
if step_callback and is_optimization_step:
|
| 216 |
+
step_callback(self._global_step, cfg.optimization.steps, sampled_videos_paths)
|
| 217 |
+
|
| 218 |
+
self._accelerator.wait_for_everyone()
|
| 219 |
+
|
| 220 |
+
# Update progress and log metrics
|
| 221 |
+
current_lr = self._optimizer.param_groups[0]["lr"]
|
| 222 |
+
step_time = (time.time() - step_start_time) * cfg.optimization.gradient_accumulation_steps
|
| 223 |
+
|
| 224 |
+
progress.update_training(
|
| 225 |
+
loss=loss.item(),
|
| 226 |
+
lr=current_lr,
|
| 227 |
+
step_time=step_time,
|
| 228 |
+
advance=is_optimization_step,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# Log metrics to W&B (only on main process and optimization steps)
|
| 232 |
+
if IS_MAIN_PROCESS and is_optimization_step:
|
| 233 |
+
self._log_metrics(
|
| 234 |
+
{
|
| 235 |
+
"train/loss": loss.item(),
|
| 236 |
+
"train/learning_rate": current_lr,
|
| 237 |
+
"train/step_time": step_time,
|
| 238 |
+
"train/global_step": self._global_step,
|
| 239 |
+
}
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Fallback logging when progress bars are disabled
|
| 243 |
+
if disable_progress_bars and IS_MAIN_PROCESS and self._global_step % 20 == 0:
|
| 244 |
+
elapsed = time.time() - train_start_time
|
| 245 |
+
progress_percentage = self._global_step / cfg.optimization.steps
|
| 246 |
+
if progress_percentage > 0:
|
| 247 |
+
total_estimated = elapsed / progress_percentage
|
| 248 |
+
total_time = f"{total_estimated // 3600:.0f}h {(total_estimated % 3600) // 60:.0f}m"
|
| 249 |
+
else:
|
| 250 |
+
total_time = "calculating..."
|
| 251 |
+
logger.info(
|
| 252 |
+
f"Step {self._global_step}/{cfg.optimization.steps} - "
|
| 253 |
+
f"Loss: {loss.item():.4f}, LR: {current_lr:.2e}, "
|
| 254 |
+
f"Time/Step: {step_time:.2f}s, Total Time: {total_time}",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Sample GPU memory periodically
|
| 258 |
+
if step % MEMORY_CHECK_INTERVAL == 0:
|
| 259 |
+
current_mem = get_gpu_memory_gb(device)
|
| 260 |
+
peak_mem_during_training = max(peak_mem_during_training, current_mem)
|
| 261 |
+
|
| 262 |
+
# Collect final stats
|
| 263 |
+
train_end_time = time.time()
|
| 264 |
+
end_mem = get_gpu_memory_gb(device)
|
| 265 |
+
peak_mem = max(start_mem, end_mem, peak_mem_during_training)
|
| 266 |
+
|
| 267 |
+
# Calculate steps/second over entire training
|
| 268 |
+
total_time_seconds = train_end_time - train_start_time
|
| 269 |
+
steps_per_second = cfg.optimization.steps / total_time_seconds
|
| 270 |
+
|
| 271 |
+
samples_per_second = steps_per_second * self._accelerator.num_processes * cfg.optimization.batch_size
|
| 272 |
+
|
| 273 |
+
stats = TrainingStats(
|
| 274 |
+
total_time_seconds=total_time_seconds,
|
| 275 |
+
steps_per_second=steps_per_second,
|
| 276 |
+
samples_per_second=samples_per_second,
|
| 277 |
+
peak_gpu_memory_gb=peak_mem,
|
| 278 |
+
num_processes=self._accelerator.num_processes,
|
| 279 |
+
global_batch_size=cfg.optimization.batch_size * self._accelerator.num_processes,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
saved_path = self._save_checkpoint()
|
| 283 |
+
|
| 284 |
+
if IS_MAIN_PROCESS:
|
| 285 |
+
# Log the training statistics
|
| 286 |
+
self._log_training_stats(stats)
|
| 287 |
+
|
| 288 |
+
# Upload artifacts to hub if enabled
|
| 289 |
+
if cfg.hub.push_to_hub:
|
| 290 |
+
push_to_hub(saved_path, sampled_videos_paths, self._config)
|
| 291 |
+
|
| 292 |
+
# Log final stats to W&B
|
| 293 |
+
if self._wandb_run is not None:
|
| 294 |
+
self._log_metrics(
|
| 295 |
+
{
|
| 296 |
+
"stats/total_time_minutes": stats.total_time_seconds / 60,
|
| 297 |
+
"stats/steps_per_second": stats.steps_per_second,
|
| 298 |
+
"stats/samples_per_second": stats.samples_per_second,
|
| 299 |
+
"stats/peak_gpu_memory_gb": stats.peak_gpu_memory_gb,
|
| 300 |
+
}
|
| 301 |
+
)
|
| 302 |
+
self._wandb_run.finish()
|
| 303 |
+
|
| 304 |
+
self._accelerator.wait_for_everyone()
|
| 305 |
+
self._accelerator.end_training()
|
| 306 |
+
|
| 307 |
+
return saved_path, stats
|
| 308 |
+
|
| 309 |
+
def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor:
|
| 310 |
+
"""Perform a single training step using the configured strategy."""
|
| 311 |
+
# Apply embedding connectors to transform pre-computed text embeddings
|
| 312 |
+
conditions = batch["conditions"]
|
| 313 |
+
|
| 314 |
+
if "video_prompt_embeds" in conditions:
|
| 315 |
+
# New format: separate video/audio features from precompute()
|
| 316 |
+
video_features = conditions["video_prompt_embeds"]
|
| 317 |
+
audio_features = conditions.get("audio_prompt_embeds")
|
| 318 |
+
else:
|
| 319 |
+
# Legacy format: single prompt_embeds tensor — duplicate for both modalities
|
| 320 |
+
video_features = conditions["prompt_embeds"]
|
| 321 |
+
audio_features = conditions["prompt_embeds"]
|
| 322 |
+
|
| 323 |
+
mask = conditions["prompt_attention_mask"]
|
| 324 |
+
additive_mask = convert_to_additive_mask(mask, video_features.dtype)
|
| 325 |
+
video_embeds, audio_embeds, attention_mask = self._embeddings_processor.create_embeddings(
|
| 326 |
+
video_features, audio_features, additive_mask
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
conditions["video_prompt_embeds"] = video_embeds
|
| 330 |
+
conditions["audio_prompt_embeds"] = audio_embeds
|
| 331 |
+
conditions["prompt_attention_mask"] = attention_mask
|
| 332 |
+
|
| 333 |
+
# Use strategy to prepare training inputs (returns ModelInputs with Modality objects)
|
| 334 |
+
model_inputs = self._training_strategy.prepare_training_inputs(batch, self._timestep_sampler)
|
| 335 |
+
|
| 336 |
+
# Run transformer forward pass with Modality-based interface
|
| 337 |
+
video_pred, audio_pred = self._transformer(
|
| 338 |
+
video=model_inputs.video,
|
| 339 |
+
audio=model_inputs.audio,
|
| 340 |
+
perturbations=None,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# Use strategy to compute loss
|
| 344 |
+
loss = self._training_strategy.compute_loss(video_pred, audio_pred, model_inputs)
|
| 345 |
+
|
| 346 |
+
return loss
|
| 347 |
+
|
| 348 |
+
@free_gpu_memory_context(after=True)
|
| 349 |
+
def _load_text_encoder_and_cache_embeddings(self) -> list[CachedPromptEmbeddings] | None:
|
| 350 |
+
"""Load text encoder + embeddings processor, compute and cache validation embeddings."""
|
| 351 |
+
|
| 352 |
+
# This method:
|
| 353 |
+
# 1. Loads the pure Gemma text encoder on GPU
|
| 354 |
+
# 2. Loads the embeddings processor (feature extractor + connectors)
|
| 355 |
+
# 3. If validation prompts are configured, computes and caches their embeddings
|
| 356 |
+
# 4. Unloads the Gemma model entirely, keeps the embeddings processor for training
|
| 357 |
+
|
| 358 |
+
# Load text encoder (pure Gemma LLM) on GPU
|
| 359 |
+
logger.debug("Loading text encoder...")
|
| 360 |
+
text_encoder = load_text_encoder(
|
| 361 |
+
gemma_model_path=self._config.model.text_encoder_path,
|
| 362 |
+
device="cuda",
|
| 363 |
+
dtype=torch.bfloat16,
|
| 364 |
+
load_in_8bit=self._config.acceleration.load_text_encoder_in_8bit,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# Load embeddings processor (feature extractor + connectors)
|
| 368 |
+
logger.debug("Loading embeddings processor...")
|
| 369 |
+
self._embeddings_processor = load_embeddings_processor(
|
| 370 |
+
checkpoint_path=self._config.model.model_path,
|
| 371 |
+
device="cuda",
|
| 372 |
+
dtype=torch.bfloat16,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
# Cache validation embeddings if prompts are configured
|
| 376 |
+
cached_embeddings = None
|
| 377 |
+
if self._config.validation.prompts:
|
| 378 |
+
logger.info(f"Pre-computing embeddings for {len(self._config.validation.prompts)} validation prompts...")
|
| 379 |
+
cached_embeddings = []
|
| 380 |
+
with torch.inference_mode():
|
| 381 |
+
for prompt in self._config.validation.prompts:
|
| 382 |
+
pos_hs, pos_mask = text_encoder.encode(prompt)
|
| 383 |
+
pos_out = self._embeddings_processor.process_hidden_states(pos_hs, pos_mask)
|
| 384 |
+
|
| 385 |
+
neg_hs, neg_mask = text_encoder.encode(self._config.validation.negative_prompt)
|
| 386 |
+
neg_out = self._embeddings_processor.process_hidden_states(neg_hs, neg_mask)
|
| 387 |
+
|
| 388 |
+
cached_embeddings.append(
|
| 389 |
+
CachedPromptEmbeddings(
|
| 390 |
+
video_context_positive=pos_out.video_encoding.cpu(),
|
| 391 |
+
audio_context_positive=pos_out.audio_encoding.cpu(),
|
| 392 |
+
video_context_negative=neg_out.video_encoding.cpu(),
|
| 393 |
+
audio_context_negative=(
|
| 394 |
+
neg_out.audio_encoding.cpu() if neg_out.audio_encoding is not None else None
|
| 395 |
+
),
|
| 396 |
+
)
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
# Unload Gemma model and feature extractor, keep only connectors for training
|
| 400 |
+
del text_encoder
|
| 401 |
+
self._embeddings_processor.feature_extractor = None
|
| 402 |
+
|
| 403 |
+
logger.debug("Validation prompt embeddings cached. Gemma model unloaded")
|
| 404 |
+
return cached_embeddings
|
| 405 |
+
|
| 406 |
+
def _load_models(self) -> None:
|
| 407 |
+
"""Load the LTX-2 model components."""
|
| 408 |
+
# Load audio components if:
|
| 409 |
+
# 1. Training strategy requires audio (training the audio branch), OR
|
| 410 |
+
# 2. Validation is configured to generate audio (even if not training audio)
|
| 411 |
+
load_audio = self._training_strategy.requires_audio or self._config.validation.generate_audio
|
| 412 |
+
|
| 413 |
+
# Check if we need VAE encoder (for image or reference video conditioning)
|
| 414 |
+
need_vae_encoder = (
|
| 415 |
+
self._config.validation.images is not None or self._config.validation.reference_videos is not None
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
# Load all model components (except text encoder - already handled)
|
| 419 |
+
components = load_ltx_model(
|
| 420 |
+
checkpoint_path=self._config.model.model_path,
|
| 421 |
+
device="cpu",
|
| 422 |
+
dtype=torch.bfloat16,
|
| 423 |
+
with_video_vae_encoder=need_vae_encoder, # Needed for image conditioning
|
| 424 |
+
with_video_vae_decoder=True, # Needed for validation sampling
|
| 425 |
+
with_audio_vae_decoder=load_audio,
|
| 426 |
+
with_vocoder=load_audio,
|
| 427 |
+
with_text_encoder=False, # Text encoder handled separately
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# Extract components
|
| 431 |
+
self._transformer = components.transformer
|
| 432 |
+
self._vae_decoder = components.video_vae_decoder.to(dtype=torch.bfloat16)
|
| 433 |
+
self._vae_encoder = components.video_vae_encoder
|
| 434 |
+
if self._vae_encoder is not None:
|
| 435 |
+
self._vae_encoder = self._vae_encoder.to(dtype=torch.bfloat16)
|
| 436 |
+
self._scheduler = components.scheduler
|
| 437 |
+
self._audio_vae = components.audio_vae_decoder
|
| 438 |
+
self._vocoder = components.vocoder
|
| 439 |
+
# Note: self._embeddings_processor was set in _load_text_encoder_and_cache_embeddings
|
| 440 |
+
|
| 441 |
+
# Determine initial dtype based on training mode.
|
| 442 |
+
# Note: For FSDP + LoRA, we'll cast to FP32 later in _prepare_models_for_training()
|
| 443 |
+
# after the accelerator is set up, and we can detect FSDP.
|
| 444 |
+
transformer_dtype = torch.bfloat16 if self._config.model.training_mode == "lora" else torch.float32
|
| 445 |
+
self._transformer = self._transformer.to(dtype=transformer_dtype)
|
| 446 |
+
|
| 447 |
+
if self._config.acceleration.quantization is not None:
|
| 448 |
+
if self._config.model.training_mode == "full":
|
| 449 |
+
raise ValueError("Quantization is not supported in full training mode.")
|
| 450 |
+
|
| 451 |
+
logger.info(f'Quantizing model with "{self._config.acceleration.quantization}". This may take a while...')
|
| 452 |
+
self._transformer = quantize_model(
|
| 453 |
+
self._transformer,
|
| 454 |
+
precision=self._config.acceleration.quantization,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
# Freeze all models. We later unfreeze the transformer based on training mode.
|
| 458 |
+
# Note: embedding_connectors are already frozen (they come from the frozen text encoder)
|
| 459 |
+
self._vae_decoder.requires_grad_(False)
|
| 460 |
+
if self._vae_encoder is not None:
|
| 461 |
+
self._vae_encoder.requires_grad_(False)
|
| 462 |
+
self._transformer.requires_grad_(False)
|
| 463 |
+
if self._audio_vae is not None:
|
| 464 |
+
self._audio_vae.requires_grad_(False)
|
| 465 |
+
if self._vocoder is not None:
|
| 466 |
+
self._vocoder.requires_grad_(False)
|
| 467 |
+
|
| 468 |
+
def _collect_trainable_params(self) -> None:
|
| 469 |
+
"""Collect trainable parameters based on training mode."""
|
| 470 |
+
if self._config.model.training_mode == "lora":
|
| 471 |
+
# For LoRA training, first set up LoRA layers
|
| 472 |
+
self._setup_lora()
|
| 473 |
+
elif self._config.model.training_mode == "full":
|
| 474 |
+
# For full training, unfreeze all transformer parameters
|
| 475 |
+
self._transformer.requires_grad_(True)
|
| 476 |
+
else:
|
| 477 |
+
raise ValueError(f"Unknown training mode: {self._config.model.training_mode}")
|
| 478 |
+
|
| 479 |
+
self._trainable_params = [p for p in self._transformer.parameters() if p.requires_grad]
|
| 480 |
+
logger.debug(f"Trainable params count: {sum(p.numel() for p in self._trainable_params):,}")
|
| 481 |
+
|
| 482 |
+
def _init_timestep_sampler(self) -> None:
|
| 483 |
+
"""Initialize the timestep sampler based on the config."""
|
| 484 |
+
sampler_cls = SAMPLERS[self._config.flow_matching.timestep_sampling_mode]
|
| 485 |
+
self._timestep_sampler = sampler_cls(**self._config.flow_matching.timestep_sampling_params)
|
| 486 |
+
|
| 487 |
+
def _setup_lora(self) -> None:
|
| 488 |
+
"""Configure LoRA adapters for the transformer. Only called in LoRA training mode."""
|
| 489 |
+
logger.debug(f"Adding LoRA adapter with rank {self._config.lora.rank}")
|
| 490 |
+
lora_config = LoraConfig(
|
| 491 |
+
r=self._config.lora.rank,
|
| 492 |
+
lora_alpha=self._config.lora.alpha,
|
| 493 |
+
target_modules=self._config.lora.target_modules,
|
| 494 |
+
lora_dropout=self._config.lora.dropout,
|
| 495 |
+
init_lora_weights=True,
|
| 496 |
+
)
|
| 497 |
+
# Wrap the transformer with PEFT to add LoRA layers
|
| 498 |
+
# noinspection PyTypeChecker
|
| 499 |
+
self._transformer = get_peft_model(self._transformer, lora_config)
|
| 500 |
+
|
| 501 |
+
def _load_checkpoint(self) -> None:
|
| 502 |
+
"""Load checkpoint if specified in config."""
|
| 503 |
+
if not self._config.model.load_checkpoint:
|
| 504 |
+
return
|
| 505 |
+
|
| 506 |
+
checkpoint_path = self._find_checkpoint(self._config.model.load_checkpoint)
|
| 507 |
+
if not checkpoint_path:
|
| 508 |
+
logger.warning(f"⚠️ Could not find checkpoint at {self._config.model.load_checkpoint}")
|
| 509 |
+
return
|
| 510 |
+
|
| 511 |
+
logger.info(f"📥 Loading checkpoint from {checkpoint_path}")
|
| 512 |
+
|
| 513 |
+
if self._config.model.training_mode == "full":
|
| 514 |
+
self._load_full_checkpoint(checkpoint_path)
|
| 515 |
+
else: # LoRA mode
|
| 516 |
+
self._load_lora_checkpoint(checkpoint_path)
|
| 517 |
+
|
| 518 |
+
def _load_full_checkpoint(self, checkpoint_path: Path) -> None:
|
| 519 |
+
"""Load full model checkpoint."""
|
| 520 |
+
state_dict = load_file(checkpoint_path)
|
| 521 |
+
self._transformer.load_state_dict(state_dict, strict=True)
|
| 522 |
+
|
| 523 |
+
logger.info("✅ Full model checkpoint loaded successfully")
|
| 524 |
+
|
| 525 |
+
def _load_lora_checkpoint(self, checkpoint_path: Path) -> None:
|
| 526 |
+
"""Load LoRA checkpoint with DDP/FSDP compatibility."""
|
| 527 |
+
state_dict = load_file(checkpoint_path)
|
| 528 |
+
|
| 529 |
+
# Adjust layer names to match internal format.
|
| 530 |
+
# (Weights are saved in ComfyUI-compatible format, with "diffusion_model." prefix)
|
| 531 |
+
state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()}
|
| 532 |
+
|
| 533 |
+
# Load LoRA weights and verify all weights were loaded
|
| 534 |
+
base_model = self._transformer.get_base_model()
|
| 535 |
+
set_peft_model_state_dict(base_model, state_dict)
|
| 536 |
+
|
| 537 |
+
logger.info("✅ LoRA checkpoint loaded successfully")
|
| 538 |
+
|
| 539 |
+
    def _prepare_models_for_training(self) -> None:
        """Prepare models for training with Accelerate."""

        # For FSDP + LoRA: cast the entire model to FP32.
        # FSDP requires a uniform dtype across all parameters in wrapped modules.
        # In LoRA mode, PEFT creates LoRA params in FP32 while the base model is BF16,
        # so we cast the base model to FP32 to match the LoRA params.
        if self._accelerator.distributed_type == DistributedType.FSDP and self._config.model.training_mode == "lora":
            logger.debug("FSDP: casting transformer to FP32 for uniform dtype")
            self._transformer = self._transformer.to(dtype=torch.float32)

        # Enable gradient checkpointing if requested.
        # For PeftModel, we need to access the underlying base model.
        transformer = (
            self._transformer.get_base_model() if hasattr(self._transformer, "get_base_model") else self._transformer
        )
        transformer.set_gradient_checkpointing(self._config.optimization.enable_gradient_checkpointing)

        # Keep frozen models on CPU for memory efficiency
        self._vae_decoder = self._vae_decoder.to("cpu")
        if self._vae_encoder is not None:
            self._vae_encoder = self._vae_encoder.to("cpu")

        # Embedding connectors are already on GPU from _load_text_encoder_and_cache_embeddings

        # noinspection PyTypeChecker
        self._transformer = self._accelerator.prepare(self._transformer)

        # Log GPU memory usage after model preparation
        vram_usage_gb = torch.cuda.memory_allocated() / 1024**3
        logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB")
    @staticmethod
    def _find_checkpoint(checkpoint_path: str | Path) -> Path | None:
        """Find the checkpoint file to load, handling both file and directory paths."""
        checkpoint_path = Path(checkpoint_path)

        if checkpoint_path.is_file():
            if checkpoint_path.suffix != ".safetensors":
                raise ValueError(f"Checkpoint file must have a .safetensors extension: {checkpoint_path}")
            return checkpoint_path

        if checkpoint_path.is_dir():
            # Look for checkpoint files in the directory
            checkpoints = list(checkpoint_path.rglob("*step_*.safetensors"))

            if not checkpoints:
                return None

            # Sort by step number and return the latest
            def _get_step_num(p: Path) -> int:
                try:
                    return int(p.stem.split("step_")[1])
                except (IndexError, ValueError):
                    return -1

            return max(checkpoints, key=_get_step_num)

        raise ValueError(f"Invalid checkpoint path: {checkpoint_path}. Must be a file or directory.")
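The numeric sort matters once step counts outgrow their zero padding, where a lexicographic sort on filenames would pick the wrong file; for example (filenames invented):

from pathlib import Path

names = [Path("model_weights_step_900.safetensors"), Path("model_weights_step_1000.safetensors")]

def _get_step_num(p: Path) -> int:
    try:
        return int(p.stem.split("step_")[1])
    except (IndexError, ValueError):
        return -1

print(max(names, key=_get_step_num).name)  # model_weights_step_1000.safetensors
print(max(names, key=str).name)            # model_weights_step_900.safetensors (lexicographic, wrong)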
    def _init_dataloader(self) -> None:
        """Initialize the training data loader using the strategy's data sources."""
        if self._dataset is None:
            # Get data sources from the training strategy
            data_sources = self._training_strategy.get_data_sources()

            self._dataset = PrecomputedDataset(self._config.data.preprocessed_data_root, data_sources=data_sources)
            logger.debug(f"Loaded dataset with {len(self._dataset):,} samples from sources: {list(data_sources)}")

        num_workers = self._config.data.num_dataloader_workers
        dataloader = DataLoader(
            self._dataset,
            batch_size=self._config.optimization.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=num_workers,
            pin_memory=num_workers > 0,
            persistent_workers=num_workers > 0,
        )

        self._dataloader = self._accelerator.prepare(dataloader)
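For reference, `get_data_sources` may return either form described in the strategy base class later in this diff; schematically:

# List form: directory names double as batch keys
sources = ["latents", "conditions"]

# Dict form: map a (configurable) directory name to a fixed batch key
sources = {"latents": "latents", "conditions": "conditions", "reference_latents": "ref_latents"}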
    def _init_lora_weights(self) -> None:
        """Initialize LoRA weights for the transformer."""
        logger.debug("Initializing LoRA weights...")
        for _, module in self._transformer.named_modules():
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                module.reset_lora_parameters(adapter_name="default", init_lora_weights=True)
    def _init_optimizer(self) -> None:
        """Initialize the optimizer and learning rate scheduler."""
        opt_cfg = self._config.optimization

        lr = opt_cfg.learning_rate
        if opt_cfg.optimizer_type == "adamw":
            optimizer = AdamW(self._trainable_params, lr=lr)
        elif opt_cfg.optimizer_type == "adamw8bit":
            # noinspection PyUnresolvedReferences
            from bitsandbytes.optim import AdamW8bit  # noqa: PLC0415

            optimizer = AdamW8bit(self._trainable_params, lr=lr)
        else:
            raise ValueError(f"Unknown optimizer type: {opt_cfg.optimizer_type}")

        # Create the learning rate scheduler (may be None)
        lr_scheduler = self._create_scheduler(optimizer)

        # noinspection PyTypeChecker
        self._optimizer, self._lr_scheduler = self._accelerator.prepare(optimizer, lr_scheduler)
    def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> LRScheduler | None:
        """Create the learning rate scheduler based on config."""
        scheduler_type = self._config.optimization.scheduler_type
        steps = self._config.optimization.steps
        params = self._config.optimization.scheduler_params or {}

        if scheduler_type is None:
            return None

        if scheduler_type == "linear":
            scheduler = LinearLR(
                optimizer,
                start_factor=params.pop("start_factor", 1.0),
                end_factor=params.pop("end_factor", 0.1),
                total_iters=steps,
                **params,
            )
        elif scheduler_type == "cosine":
            scheduler = CosineAnnealingLR(
                optimizer,
                T_max=steps,
                eta_min=params.pop("eta_min", 0),
                **params,
            )
        elif scheduler_type == "cosine_with_restarts":
            scheduler = CosineAnnealingWarmRestarts(
                optimizer,
                T_0=params.pop("T_0", steps // 4),  # First restart cycle length
                T_mult=params.pop("T_mult", 1),  # Multiplicative factor for cycle lengths
                eta_min=params.pop("eta_min", 5e-5),
                **params,
            )
        elif scheduler_type == "polynomial":
            scheduler = PolynomialLR(
                optimizer,
                total_iters=steps,
                power=params.pop("power", 1.0),
                **params,
            )
        elif scheduler_type == "step":
            scheduler = StepLR(
                optimizer,
                step_size=params.pop("step_size", steps // 2),
                gamma=params.pop("gamma", 0.1),
                **params,
            )
        elif scheduler_type == "constant":
            scheduler = None
        else:
            raise ValueError(f"Unknown scheduler type: {scheduler_type}")

        return scheduler
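Each branch above pops its known keys with defaults and forwards the remainder, so a config only needs to override what it cares about. A hypothetical `scheduler_params` for the cosine-with-restarts branch with steps = 1000 (values invented for illustration):

params = {"T_0": 500, "eta_min": 1e-5}

T_0 = params.pop("T_0", 1000 // 4)     # 500: the override wins over the steps-derived default
T_mult = params.pop("T_mult", 1)       # 1: default used, key absent
eta_min = params.pop("eta_min", 5e-5)  # 1e-5
print(params)                          # {}: nothing left to forward via **params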
    def _setup_accelerator(self) -> None:
        """Initialize the Accelerator with the appropriate settings."""

        # All distributed setup (DDP/FSDP, number of processes, etc.) is controlled by
        # the user's Accelerate configuration (accelerate config / accelerate launch).
        self._accelerator = Accelerator(
            mixed_precision=self._config.acceleration.mixed_precision_mode,
            gradient_accumulation_steps=self._config.optimization.gradient_accumulation_steps,
        )

        if self._accelerator.num_processes > 1:
            logger.info(
                f"{self._accelerator.distributed_type.value} distributed training enabled "
                f"with {self._accelerator.num_processes} processes"
            )

            local_batch = self._config.optimization.batch_size
            global_batch = self._config.optimization.batch_size * self._accelerator.num_processes
            logger.info(f"Local batch size: {local_batch}, global batch size: {global_batch}")

        # Log torch.compile status from Accelerate's dynamo plugin
        is_compile_enabled = (
            hasattr(self._accelerator.state, "dynamo_plugin") and self._accelerator.state.dynamo_plugin.backend != "NO"
        )
        if is_compile_enabled:
            plugin = self._accelerator.state.dynamo_plugin
            logger.info(f"🔥 torch.compile enabled via Accelerate: backend={plugin.backend}, mode={plugin.mode}")

            if self._accelerator.distributed_type == DistributedType.FSDP:
                logger.warning(
                    "⚠️ FSDP + torch.compile is experimental and may hang on the first training iteration. "
                    "If this occurs, disable torch.compile by removing dynamo_config from your Accelerate config."
                )

        if self._accelerator.distributed_type == DistributedType.FSDP and self._config.acceleration.quantization:
            logger.warning(
                f"FSDP with quantization ({self._config.acceleration.quantization}) may have compatibility issues. "
                "Monitor training stability and consider disabling quantization if issues arise."
            )
    # Note: use @torch.no_grad() instead of @torch.inference_mode()
    # to avoid FSDP in-place update errors after validation.
    @torch.no_grad()
    @free_gpu_memory_context(after=True)
    def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None:
        """Run validation by generating videos from validation prompts."""
        use_images = self._config.validation.images is not None
        use_reference_videos = self._config.validation.reference_videos is not None
        generate_audio = self._config.validation.generate_audio
        inference_steps = self._config.validation.inference_steps

        # Zero gradients and free GPU memory to reclaim memory before validation sampling
        self._optimizer.zero_grad(set_to_none=True)
        free_gpu_memory()

        # Start sampling progress tracking
        sampling_ctx = progress.start_sampling(
            num_prompts=len(self._config.validation.prompts),
            num_steps=inference_steps,
        )

        # Create a validation sampler with the loaded models and progress tracking
        sampler = ValidationSampler(
            transformer=self._transformer,
            vae_decoder=self._vae_decoder,
            vae_encoder=self._vae_encoder,
            text_encoder=None,
            audio_decoder=self._audio_vae if generate_audio else None,
            vocoder=self._vocoder if generate_audio else None,
            sampling_context=sampling_ctx,
        )

        output_dir = Path(self._config.output_dir) / "samples"
        output_dir.mkdir(exist_ok=True, parents=True)

        video_paths = []
        width, height, num_frames = self._config.validation.video_dims

        for prompt_idx, prompt in enumerate(self._config.validation.prompts):
            # Update progress to show the current video
            sampling_ctx.start_video(prompt_idx)

            # Load the conditioning image if provided
            condition_image = None
            if use_images:
                image_path = self._config.validation.images[prompt_idx]
                image = open_image_as_srgb(image_path)
                # Convert PIL image to tensor [C, H, W] in [0, 1]
                condition_image = F.to_tensor(image)

            # Load the reference video if provided (for IC-LoRA)
            reference_video = None
            if use_reference_videos:
                ref_video_path = self._config.validation.reference_videos[prompt_idx]
                # read_video returns [F, C, H, W] in [0, 1]
                reference_video, _ = read_video(ref_video_path, max_frames=num_frames)

            # Get cached embeddings for this prompt if available
            cached_embeddings = (
                self._cached_validation_embeddings[prompt_idx]
                if self._cached_validation_embeddings is not None
                else None
            )

            # Create the generation config
            gen_config = GenerationConfig(
                prompt=prompt,
                negative_prompt=self._config.validation.negative_prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=self._config.validation.frame_rate,
                num_inference_steps=inference_steps,
                guidance_scale=self._config.validation.guidance_scale,
                seed=self._config.validation.seed,
                condition_image=condition_image,
                reference_video=reference_video,
                reference_downscale_factor=self._config.validation.reference_downscale_factor,
                generate_audio=generate_audio,
                include_reference_in_output=self._config.validation.include_reference_in_output,
                cached_embeddings=cached_embeddings,
                stg_scale=self._config.validation.stg_scale,
                stg_blocks=self._config.validation.stg_blocks,
                stg_mode=self._config.validation.stg_mode,
            )

            # Generate a sample
            video, audio = sampler.generate(
                config=gen_config,
                device=self._accelerator.device,
            )

            # Save the output (image for a single frame, video otherwise)
            if IS_MAIN_PROCESS:
                ext = "png" if num_frames == 1 else "mp4"
                output_path = output_dir / f"step_{self._global_step:06d}_{prompt_idx + 1}.{ext}"
                if num_frames == 1:
                    save_image(video, output_path)
                else:
                    save_video(
                        video_tensor=video,
                        output_path=output_path,
                        fps=self._config.validation.frame_rate,
                        audio=audio,
                        audio_sample_rate=self._vocoder.output_sampling_rate if audio is not None else None,
                    )
                video_paths.append(output_path)

        # Clean up progress tasks
        sampling_ctx.cleanup()

        rel_outputs_path = output_dir.relative_to(self._config.output_dir)
        logger.info(f"🎥 Validation samples for step {self._global_step} saved in {rel_outputs_path}")
        return video_paths
    @staticmethod
    def _log_training_stats(stats: TrainingStats) -> None:
        """Log training statistics."""
        stats_str = (
            "📊 Training Statistics:\n"
            f" - Total time: {stats.total_time_seconds / 60:.1f} minutes\n"
            f" - Training speed: {stats.steps_per_second:.2f} steps/second\n"
            f" - Samples/second: {stats.samples_per_second:.2f}\n"
            f" - Peak GPU memory: {stats.peak_gpu_memory_gb:.2f} GB"
        )
        if stats.num_processes > 1:
            stats_str += f"\n - Number of processes: {stats.num_processes}\n"
            stats_str += f" - Global batch size: {stats.global_batch_size}"
        logger.info(stats_str)
    def _save_checkpoint(self) -> Path | None:
        """Save the model weights."""
        is_lora = self._config.model.training_mode == "lora"
        is_fsdp = self._accelerator.distributed_type == DistributedType.FSDP

        # Prepare paths
        save_dir = Path(self._config.output_dir) / "checkpoints"
        prefix = "lora" if is_lora else "model"
        filename = f"{prefix}_weights_step_{self._global_step:05d}.safetensors"
        saved_weights_path = save_dir / filename

        # Get the state dict (a collective operation - all processes must participate)
        self._accelerator.wait_for_everyone()
        full_state_dict = self._accelerator.get_state_dict(self._transformer)

        if not IS_MAIN_PROCESS:
            return None

        save_dir.mkdir(exist_ok=True, parents=True)

        # Determine save precision
        save_dtype = torch.bfloat16 if self._config.checkpoints.precision == "bfloat16" else torch.float32

        # For LoRA: extract only adapter weights; for full: use as-is
        if is_lora:
            unwrapped = self._accelerator.unwrap_model(self._transformer, keep_torch_compile=False)
            # For FSDP, pass full_state_dict since model params aren't directly accessible
            state_dict = get_peft_model_state_dict(unwrapped, state_dict=full_state_dict if is_fsdp else None)

            # Remove the "base_model.model." prefix added by PEFT
            state_dict = {k.replace("base_model.model.", "", 1): v for k, v in state_dict.items()}

            # Convert to ComfyUI-compatible format (add the "diffusion_model." prefix)
            state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()}

            # Cast to the configured precision
            state_dict = {k: v.to(save_dtype) if isinstance(v, Tensor) else v for k, v in state_dict.items()}

            # Build metadata for the safetensors file
            metadata = self._build_checkpoint_metadata()

            # Save to disk with metadata
            save_file(state_dict, saved_weights_path, metadata=metadata)
        else:
            # Cast to the configured precision
            full_state_dict = {k: v.to(save_dtype) if isinstance(v, Tensor) else v for k, v in full_state_dict.items()}

            # Save to disk
            self._accelerator.save(full_state_dict, saved_weights_path)

        rel_path = saved_weights_path.relative_to(self._config.output_dir)
        logger.info(f"💾 {prefix.capitalize()} weights for step {self._global_step} saved in {rel_path}")

        # Keep track of checkpoint paths, and clean up old checkpoints if needed
        self._checkpoint_paths.append(saved_weights_path)
        self._cleanup_checkpoints()
        return saved_weights_path
    def _cleanup_checkpoints(self) -> None:
        """Clean up old checkpoints."""
        if 0 < self._config.checkpoints.keep_last_n < len(self._checkpoint_paths):
            checkpoints_to_remove = self._checkpoint_paths[: -self._config.checkpoints.keep_last_n]
            for old_checkpoint in checkpoints_to_remove:
                if old_checkpoint.exists():
                    old_checkpoint.unlink()
                    logger.info(f"Removed old checkpoint: {old_checkpoint}")
            # Update the list to only contain the kept checkpoints
            self._checkpoint_paths = self._checkpoint_paths[-self._config.checkpoints.keep_last_n :]
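The slice arithmetic above keeps exactly the newest `keep_last_n` entries; a worked example with keep_last_n = 2 (paths abbreviated for illustration):

checkpoint_paths = ["step_100", "step_200", "step_300", "step_400"]
keep_last_n = 2

to_remove = checkpoint_paths[:-keep_last_n]  # ['step_100', 'step_200']
kept = checkpoint_paths[-keep_last_n:]       # ['step_300', 'step_400']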
    def _build_checkpoint_metadata(self) -> dict[str, str]:
        """Build the metadata dictionary for a safetensors checkpoint.

        Delegates to the training strategy to get strategy-specific metadata
        that downstream inference pipelines may need.

        Returns:
            Dictionary of string key-value pairs for safetensors metadata.
            Values are converted to strings for safetensors compatibility.
        """
        raw_metadata = self._training_strategy.get_checkpoint_metadata()
        # Convert all values to strings for safetensors compatibility
        metadata = {k: str(v) for k, v in raw_metadata.items()}
        if metadata:
            logger.info(f"Saving checkpoint metadata: {metadata}")
        return metadata
    def _save_config(self) -> None:
        """Save the training configuration as a YAML file in the output directory."""
        if not IS_MAIN_PROCESS:
            return

        config_path = Path(self._config.output_dir) / "training_config.yaml"
        with open(config_path, "w") as f:
            yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2)

        logger.info(f"💾 Training configuration saved to: {config_path.relative_to(self._config.output_dir)}")

    def _init_wandb(self) -> None:
        """Initialize the Weights & Biases run."""
        if not self._config.wandb.enabled or not IS_MAIN_PROCESS:
            self._wandb_run = None
            return

        wandb_config = self._config.wandb
        run = wandb.init(
            project=wandb_config.project,
            entity=wandb_config.entity,
            name=Path(self._config.output_dir).name,
            tags=wandb_config.tags,
            config=self._config.model_dump(),
        )
        self._wandb_run = run

    def _log_metrics(self, metrics: dict[str, float]) -> None:
        """Log metrics to Weights & Biases."""
        if self._wandb_run is not None:
            self._wandb_run.log(metrics)

    def _log_validation_samples(self, sample_paths: list[Path], prompts: list[str]) -> None:
        """Log validation samples (videos or images) to Weights & Biases."""
        if not self._config.wandb.log_validation_videos or self._wandb_run is None:
            return

        # Determine whether outputs are images or videos based on the file extension
        is_image = sample_paths and sample_paths[0].suffix.lower() in (".png", ".jpg", ".jpeg", ".heic", ".webp")
        media_cls = wandb.Image if is_image else wandb.Video

        samples = [media_cls(str(path), caption=prompt) for path, prompt in zip(sample_paths, prompts, strict=True)]
        self._wandb_run.log({"validation_samples": samples}, step=self._global_step)
packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py
ADDED
@@ -0,0 +1,58 @@
"""Training strategies for different conditioning modes.

This package implements the Strategy Pattern to handle different training modes:
- Text-to-video training (standard generation, optionally with audio)
- Video-to-video training (IC-LoRA mode with reference videos)

Each strategy encapsulates the specific logic for preparing model inputs and computing loss.
"""

from ltx_trainer import logger
from ltx_trainer.training_strategies.base_strategy import (
    DEFAULT_FPS,
    VIDEO_SCALE_FACTORS,
    ModelInputs,
    TrainingStrategy,
    TrainingStrategyConfigBase,
)
from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig, TextToVideoStrategy
from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig, VideoToVideoStrategy

# Type alias for all strategy config types
TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig

__all__ = [
    "DEFAULT_FPS",
    "VIDEO_SCALE_FACTORS",
    "ModelInputs",
    "TextToVideoConfig",
    "TextToVideoStrategy",
    "TrainingStrategy",
    "TrainingStrategyConfig",
    "TrainingStrategyConfigBase",
    "VideoToVideoConfig",
    "VideoToVideoStrategy",
    "get_training_strategy",
]


def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy:
    """Factory function to create the appropriate training strategy.

    The strategy is determined by the `name` field in the configuration.

    Args:
        config: Strategy-specific configuration with a `name` field

    Returns:
        The appropriate training strategy instance

    Raises:
        ValueError: If the strategy name is not supported
    """
    match config:
        case TextToVideoConfig():
            strategy = TextToVideoStrategy(config)
        case VideoToVideoConfig():
            strategy = VideoToVideoStrategy(config)
        case _:
            raise ValueError(f"Unknown training strategy config type: {type(config).__name__}")

    audio_mode = "(audio enabled 🔈)" if getattr(config, "with_audio", False) else "(audio disabled 🔇)"
    logger.debug(f"🎯 Using {strategy.__class__.__name__} training strategy {audio_mode}")
    return strategy
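A minimal usage sketch of the factory above (field values invented for illustration):

from ltx_trainer.training_strategies import TextToVideoConfig, get_training_strategy

config = TextToVideoConfig(first_frame_conditioning_p=0.2, with_audio=True)
strategy = get_training_strategy(config)  # -> TextToVideoStrategy, audio enabled
print(strategy.requires_audio)            # True
print(strategy.get_data_sources())        # {'latents': 'latents', 'conditions': 'conditions', 'audio_latents': 'audio_latents'}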
packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py
ADDED
@@ -0,0 +1,262 @@
"""Base class for training strategies.

This module defines the abstract base class that all training strategies must implement,
along with the base configuration class.
"""

import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Literal

import torch
from pydantic import BaseModel, ConfigDict, Field
from torch import Tensor

from ltx_core.components.patchifiers import (
    AudioPatchifier,
    VideoLatentPatchifier,
    get_pixel_coords,
)
from ltx_core.model.transformer.modality import Modality
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
from ltx_trainer.timestep_samplers import TimestepSampler

# Default frames per second for videos missing FPS metadata
DEFAULT_FPS = 24

# VAE scale factors for LTX-2
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()


class TrainingStrategyConfigBase(BaseModel):
    """Base configuration class for training strategies.

    All strategy-specific configuration classes should inherit from this.
    """

    model_config = ConfigDict(extra="forbid")

    name: Literal["text_to_video", "video_to_video"] = Field(
        description="Unique name identifying the training strategy type"
    )


@dataclass
class ModelInputs:
    """Container for model inputs using the Modality-based interface."""

    video: Modality
    audio: Modality | None

    # Training targets (for loss computation)
    video_targets: Tensor
    audio_targets: Tensor | None

    # Masks for loss computation
    video_loss_mask: Tensor  # Boolean mask: True = compute loss for this token
    audio_loss_mask: Tensor | None

    # Metadata needed for loss computation in some strategies
    ref_seq_len: int | None = None  # For IC-LoRA: length of the reference sequence


class TrainingStrategy(ABC):
    """Abstract base class for training strategies.

    Each strategy encapsulates the logic for a specific training mode,
    handling input preparation and loss computation.
    """

    def __init__(self, config: TrainingStrategyConfigBase):
        """Initialize the strategy with its configuration.

        Args:
            config: Strategy-specific configuration
        """
        self.config = config
        self._video_patchifier = VideoLatentPatchifier(patch_size=1)
        self._audio_patchifier = AudioPatchifier(patch_size=1)

    @property
    def requires_audio(self) -> bool:
        """Whether this training strategy requires audio components.

        Override this property in subclasses that support audio training.
        The trainer uses this to determine whether to load the audio VAE and vocoder.

        Returns:
            True if audio components should be loaded, False otherwise.
        """
        return False

    @abstractmethod
    def get_data_sources(self) -> list[str] | dict[str, str]:
        """Get the required data sources for this training strategy.

        Returns:
            Either a list of data directory names (where output keys match directory names)
            or a dictionary mapping data directory names to custom output keys for the dataset
        """

    @abstractmethod
    def prepare_training_inputs(
        self,
        batch: dict[str, Any],
        timestep_sampler: TimestepSampler,
    ) -> ModelInputs:
        """Prepare training inputs from a raw data batch.

        Args:
            batch: Raw batch data from the dataset. Contains:
                - "latents": Video latent data
                - "conditions": Text embeddings with keys:
                    - "video_prompt_embeds": Already processed by embedding connectors
                    - "audio_prompt_embeds": Already processed by embedding connectors
                    - "prompt_attention_mask": Attention mask
                - Additional keys depending on strategy (e.g., "ref_latents" for IC-LoRA)
            timestep_sampler: Sampler for generating timesteps and noise

        Returns:
            ModelInputs containing Modality objects and training targets
        """

    @abstractmethod
    def compute_loss(
        self,
        video_pred: Tensor,
        audio_pred: Tensor | None,
        inputs: ModelInputs,
    ) -> Tensor:
        """Compute the training loss.

        Args:
            video_pred: Video prediction from the transformer model
            audio_pred: Audio prediction from the transformer model (None for video-only)
            inputs: The prepared model inputs containing targets and masks

        Returns:
            Scalar loss tensor
        """

    def get_checkpoint_metadata(self) -> dict[str, Any]:
        """Get strategy-specific metadata to include in checkpoint files.

        Override this method in subclasses to add custom metadata,
        e.g. any parameters that a downstream inference pipeline may need.

        Returns:
            Dictionary of metadata key-value pairs (values must be JSON-serializable)
        """
        return {}

    def _get_video_positions(
        self,
        num_frames: int,
        height: int,
        width: int,
        batch_size: int,
        fps: float,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor:
        """Generate video position embeddings using ltx_core's native implementation.

        Args:
            num_frames: Number of latent frames
            height: Latent height
            width: Latent width
            batch_size: Batch size
            fps: Frames per second
            device: Target device
            dtype: Target dtype

        Returns:
            Position tensor of shape [B, 3, seq_len, 2]
        """
        latent_coords = self._video_patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape(
                frames=num_frames,
                height=height,
                width=width,
                batch=batch_size,
                channels=128,  # Video latent channels
            ),
            device=device,
        )

        # Convert latent coords to pixel coords with the causal fix
        pixel_coords = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=VIDEO_SCALE_FACTORS,
            causal_fix=True,
        ).to(dtype)

        # Scale the temporal dimension by 1/fps to get time in seconds
        pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps

        return pixel_coords

    def _get_audio_positions(
        self,
        num_time_steps: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor:
        """Generate audio position embeddings using ltx_core's native implementation.

        Args:
            num_time_steps: Number of audio time steps (T, not T*mel_bins)
            batch_size: Batch size
            device: Target device
            dtype: Target dtype

        Returns:
            Position tensor of shape [B, 1, num_time_steps, 2]

        Note:
            Audio latents should be in patchified format [B, T, C*F] = [B, T, 128],
            where T is the number of time steps, C=8 channels, F=16 mel bins.
            This matches the format produced by AudioPatchifier.patchify().
        """
        mel_bins = 16

        latent_coords = self._audio_patchifier.get_patch_grid_bounds(
            output_shape=AudioLatentShape(
                frames=num_time_steps,
                mel_bins=mel_bins,
                batch=batch_size,
                channels=8,  # Audio latent channels
            ),
            device=device,
        )

        return latent_coords.to(dtype)
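The note above pins down the audio token layout. A shape-only sketch of the [B, C, T, F] -> [B, T, C*F] patchify step, using plain tensor ops to stand in for AudioPatchifier (the exact channel/mel interleaving is whatever AudioPatchifier defines, so treat the permute order as illustrative):

import torch

B, C, T, F = 2, 8, 100, 16  # batch, channels, time steps, mel bins
latents = torch.randn(B, C, T, F)

# Flatten channels and mel bins into one token dimension per time step
tokens = latents.permute(0, 2, 1, 3).reshape(B, T, C * F)
print(tokens.shape)  # torch.Size([2, 100, 128])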
    @staticmethod
    def _create_per_token_timesteps(conditioning_mask: Tensor, sampled_sigma: Tensor) -> Tensor:
        """Create per-token timesteps based on the conditioning mask.

        Args:
            conditioning_mask: Boolean mask of shape (batch_size, sequence_length),
                where True = conditioning token (timestep=0), False = target token (use sigma)
            sampled_sigma: Sampled sigma values of shape (batch_size,) or (batch_size, 1, 1)

        Returns:
            Timesteps tensor of shape [batch_size, sequence_length]
        """
        # Expand to match the conditioning mask shape [B, seq_len]
        expanded_sigma = sampled_sigma.view(-1, 1).expand_as(conditioning_mask)

        # Conditioning tokens get 0, target tokens get the sampled sigma
        return torch.where(conditioning_mask, torch.zeros_like(expanded_sigma), expanded_sigma)
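A tiny worked example of the mask-to-timestep mapping above (values arbitrary):

import torch

conditioning_mask = torch.tensor([[True, True, False, False]])  # 2 conditioning + 2 target tokens
sampled_sigma = torch.tensor([0.7])

expanded = sampled_sigma.view(-1, 1).expand_as(conditioning_mask)
timesteps = torch.where(conditioning_mask, torch.zeros_like(expanded), expanded)
print(timesteps)  # tensor([[0.0000, 0.0000, 0.7000, 0.7000]])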
    @staticmethod
    def _create_first_frame_conditioning_mask(
        batch_size: int,
        sequence_length: int,
        height: int,
        width: int,
        device: torch.device,
        first_frame_conditioning_p: float = 0.0,
    ) -> Tensor:
        """Create the conditioning mask for first-frame conditioning.

        Args:
            batch_size: Batch size
            sequence_length: Total sequence length
            height: Latent height
            width: Latent width
            device: Target device
            first_frame_conditioning_p: Probability of conditioning on the first frame

        Returns:
            Boolean mask where True indicates first-frame tokens (if conditioning is enabled)
        """
        conditioning_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)

        if first_frame_conditioning_p > 0 and random.random() < first_frame_conditioning_p:
            first_frame_end_idx = height * width
            if first_frame_end_idx < sequence_length:
                conditioning_mask[:, :first_frame_end_idx] = True

        return conditioning_mask
packages/ltx-trainer/src/ltx_trainer/training_strategies/text_to_video.py
ADDED
@@ -0,0 +1,291 @@
"""Text-to-video training strategy.

This strategy implements standard text-to-video generation training where:
- Only target latents are used (no reference videos)
- Standard noise application and loss computation
- Supports first-frame conditioning
- Optionally supports joint audio-video training
"""

from typing import Any, Literal

import torch
from pydantic import Field
from torch import Tensor

from ltx_core.model.transformer.modality import Modality
from ltx_trainer import logger
from ltx_trainer.timestep_samplers import TimestepSampler
from ltx_trainer.training_strategies.base_strategy import (
    DEFAULT_FPS,
    ModelInputs,
    TrainingStrategy,
    TrainingStrategyConfigBase,
)


class TextToVideoConfig(TrainingStrategyConfigBase):
    """Configuration for the text-to-video training strategy."""

    name: Literal["text_to_video"] = "text_to_video"

    first_frame_conditioning_p: float = Field(
        default=0.1,
        description="Probability of conditioning on the first frame during training",
        ge=0.0,
        le=1.0,
    )

    with_audio: bool = Field(
        default=False,
        description="Whether to include audio in training (joint audio-video generation)",
    )

    audio_latents_dir: str = Field(
        default="audio_latents",
        description="Directory name for audio latents when with_audio is True",
    )


class TextToVideoStrategy(TrainingStrategy):
    """Text-to-video training strategy.

    This strategy implements regular video generation training where:
    - Only target latents are used (no reference videos)
    - Standard noise application and loss computation
    - Supports first-frame conditioning
    - Optionally supports joint audio-video training when with_audio=True
    """

    config: TextToVideoConfig

    def __init__(self, config: TextToVideoConfig):
        """Initialize the strategy with its configuration.

        Args:
            config: Text-to-video configuration
        """
        super().__init__(config)

    @property
    def requires_audio(self) -> bool:
        """Whether this training strategy requires audio components."""
        return self.config.with_audio

    def get_data_sources(self) -> list[str] | dict[str, str]:
        """Text-to-video training requires latents and text conditions.

        When with_audio is True, audio latents are also required.
        """
        sources = {
            "latents": "latents",
            "conditions": "conditions",
        }

        if self.config.with_audio:
            sources[self.config.audio_latents_dir] = "audio_latents"

        return sources

    def prepare_training_inputs(
        self,
        batch: dict[str, Any],
        timestep_sampler: TimestepSampler,
    ) -> ModelInputs:
        """Prepare inputs for text-to-video training."""
        # Get pre-encoded latents - the dataset provides a uniform non-patchified format [B, C, F, H, W]
        latents = batch["latents"]
        video_latents = latents["latents"]

        # Get video dimensions (assumed identical for all batch elements)
        num_frames = latents["num_frames"][0].item()
        height = latents["height"][0].item()
        width = latents["width"][0].item()

        # Patchify latents: [B, C, F, H, W] -> [B, seq_len, C]
        video_latents = self._video_patchifier.patchify(video_latents)

        # Handle FPS with backward compatibility
        fps = latents.get("fps", None)
        if fps is not None and not torch.all(fps == fps[0]):
            logger.warning(
                f"Different FPS values found in the batch. Found: {fps.tolist()}, using the first one: {fps[0].item()}"
            )
        fps = fps[0].item() if fps is not None else DEFAULT_FPS

        # Get text embeddings (already processed by embedding connectors in the trainer)
        conditions = batch["conditions"]
        video_prompt_embeds = conditions["video_prompt_embeds"]
        audio_prompt_embeds = conditions["audio_prompt_embeds"]
        prompt_attention_mask = conditions["prompt_attention_mask"]

        batch_size = video_latents.shape[0]
        video_seq_len = video_latents.shape[1]
        device = video_latents.device
        dtype = video_latents.dtype

        # Create the conditioning mask (first-frame conditioning)
        video_conditioning_mask = self._create_first_frame_conditioning_mask(
            batch_size=batch_size,
            sequence_length=video_seq_len,
            height=height,
            width=width,
            device=device,
            first_frame_conditioning_p=self.config.first_frame_conditioning_p,
        )

        # Sample noise and sigmas
        sigmas = timestep_sampler.sample_for(video_latents)
        video_noise = torch.randn_like(video_latents)

        # Apply noise: noisy = (1 - sigma) * clean + sigma * noise
        sigmas_expanded = sigmas.view(-1, 1, 1)
        noisy_video = (1 - sigmas_expanded) * video_latents + sigmas_expanded * video_noise

        # For conditioning tokens, use the clean latents
        conditioning_mask_expanded = video_conditioning_mask.unsqueeze(-1)
        noisy_video = torch.where(conditioning_mask_expanded, video_latents, noisy_video)

        # Compute video targets (velocity prediction)
        video_targets = video_noise - video_latents

        # Create per-token timesteps
        video_timesteps = self._create_per_token_timesteps(video_conditioning_mask, sigmas.squeeze())

        # Generate video positions using ltx_core's native implementation
        video_positions = self._get_video_positions(
            num_frames=num_frames,
            height=height,
            width=width,
            batch_size=batch_size,
            fps=fps,
            device=device,
            dtype=dtype,
        )

        # Create the video Modality
        video_modality = Modality(
            enabled=True,
            sigma=sigmas,
            latent=noisy_video,
            timesteps=video_timesteps,
            positions=video_positions,
            context=video_prompt_embeds,
            context_mask=prompt_attention_mask,
        )

        # Video loss mask: True for tokens we want to compute loss on (non-conditioning tokens)
        video_loss_mask = ~video_conditioning_mask

        # Handle audio if enabled
        audio_modality = None
        audio_targets = None
        audio_loss_mask = None

        if self.config.with_audio:
            audio_modality, audio_targets, audio_loss_mask = self._prepare_audio_inputs(
                batch=batch,
                sigmas=sigmas,
                audio_prompt_embeds=audio_prompt_embeds,
                prompt_attention_mask=prompt_attention_mask,
                batch_size=batch_size,
                device=device,
                dtype=dtype,
            )

        return ModelInputs(
            video=video_modality,
            audio=audio_modality,
            video_targets=video_targets,
            audio_targets=audio_targets,
            video_loss_mask=video_loss_mask,
            audio_loss_mask=audio_loss_mask,
        )

    def _prepare_audio_inputs(
        self,
        batch: dict[str, Any],
        sigmas: Tensor,
        audio_prompt_embeds: Tensor,
        prompt_attention_mask: Tensor,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[Modality, Tensor, Tensor]:
        """Prepare audio inputs for joint audio-video training.

        Args:
            batch: Raw batch data containing audio_latents
            sigmas: Sampled sigma values (same as video)
            audio_prompt_embeds: Audio context embeddings
            prompt_attention_mask: Attention mask for the context
            batch_size: Batch size
            device: Target device
            dtype: Target dtype

        Returns:
            Tuple of (audio_modality, audio_targets, audio_loss_mask)
        """
        # Get audio latents - the dataset provides a uniform non-patchified format [B, C, T, F]
        audio_data = batch["audio_latents"]
        audio_latents = audio_data["latents"]

        # Patchify audio latents: [B, C, T, F] -> [B, T, C*F]
        audio_latents = self._audio_patchifier.patchify(audio_latents)

        audio_seq_len = audio_latents.shape[1]

        # Sample audio noise
        audio_noise = torch.randn_like(audio_latents)

        # Apply noise to audio (same sigma as video)
        sigmas_expanded = sigmas.view(-1, 1, 1)
        noisy_audio = (1 - sigmas_expanded) * audio_latents + sigmas_expanded * audio_noise

        # Compute audio targets
        audio_targets = audio_noise - audio_latents

        # Audio timesteps: all tokens use the sampled sigma (no conditioning mask)
        audio_timesteps = sigmas.view(-1, 1).expand(-1, audio_seq_len)

        # Generate audio positions
        audio_positions = self._get_audio_positions(
            num_time_steps=audio_seq_len,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
        )

        # Create the audio Modality
        audio_modality = Modality(
            enabled=True,
            latent=noisy_audio,
            sigma=sigmas,
            timesteps=audio_timesteps,
            positions=audio_positions,
            context=audio_prompt_embeds,
            context_mask=prompt_attention_mask,
        )

        # Audio loss mask: all tokens contribute to the loss (no conditioning)
        audio_loss_mask = torch.ones(batch_size, audio_seq_len, dtype=torch.bool, device=device)

        return audio_modality, audio_targets, audio_loss_mask
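Both the video and audio paths above use the same flow-matching construction: noisy = (1 - sigma) * clean + sigma * noise with velocity target v = noise - clean, so a single Euler step from the noisy latent recovers the clean one. A quick numeric check of that identity (shapes arbitrary):

import torch

clean = torch.randn(2, 16, 4)
noise = torch.randn_like(clean)
sigma = torch.tensor([0.3, 0.8]).view(-1, 1, 1)

noisy = (1 - sigma) * clean + sigma * noise
target_v = noise - clean

# noisy - sigma * v == (1 - sigma)*clean + sigma*noise - sigma*(noise - clean) == clean
print(torch.allclose(noisy - sigma * target_v, clean, atol=1e-6))  # True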
    def compute_loss(
        self,
        video_pred: Tensor,
        audio_pred: Tensor | None,
        inputs: ModelInputs,
    ) -> Tensor:
        """Compute the masked MSE loss for video and, optionally, audio."""
        # Video loss
        video_loss = (video_pred - inputs.video_targets).pow(2)
        video_loss_mask = inputs.video_loss_mask.unsqueeze(-1).float()
        video_loss = video_loss.mul(video_loss_mask).div(video_loss_mask.mean())
        video_loss = video_loss.mean()

        # If there is no audio, return the video loss only
        if not self.config.with_audio or audio_pred is None or inputs.audio_targets is None:
            return video_loss

        # Audio loss (no conditioning mask)
        audio_loss = (audio_pred - inputs.audio_targets).pow(2).mean()

        # Combined loss
        return video_loss + audio_loss
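The mask normalization in `compute_loss` is worth unpacking: `mul(mask).div(mask.mean()).mean()` equals the plain mean over just the unmasked tokens. A small check (values arbitrary):

import torch

loss = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]])  # loss only on the last two tokens

normalized = loss.mul(mask).div(mask.mean()).mean()
print(normalized)                # tensor(3.5000)
print(loss[mask.bool()].mean())  # tensor(3.5000) -- same result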
packages/ltx-trainer/src/ltx_trainer/training_strategies/video_to_video.py
ADDED
@@ -0,0 +1,303 @@
"""Video-to-video training strategy for IC-LoRA.
This strategy implements training with reference video conditioning where:
- Reference latents (clean) are concatenated with target latents (noised)
- Video coordinates handle both reference and target sequences
- Loss is computed only on the target portion
"""

from typing import Any, Literal

import torch
from pydantic import Field
from torch import Tensor

from ltx_core.model.transformer.modality import Modality
from ltx_trainer import logger
from ltx_trainer.timestep_samplers import TimestepSampler
from ltx_trainer.training_strategies.base_strategy import (
    DEFAULT_FPS,
    ModelInputs,
    TrainingStrategy,
    TrainingStrategyConfigBase,
)


class VideoToVideoConfig(TrainingStrategyConfigBase):
    """Configuration for the video-to-video (IC-LoRA) training strategy."""

    name: Literal["video_to_video"] = "video_to_video"

    first_frame_conditioning_p: float = Field(
        default=0.1,
        description="Probability of conditioning on the first frame during training",
        ge=0.0,
        le=1.0,
    )

    reference_latents_dir: str = Field(
        default="reference_latents",
        description="Directory name for latents of reference videos",
    )


class VideoToVideoStrategy(TrainingStrategy):
    """Video-to-video training strategy for IC-LoRA.
    This strategy implements training with reference video conditioning where:
    - Reference latents (clean) are concatenated with target latents (noised)
    - Video coordinates handle both reference and target sequences
    - Loss is computed only on the target portion

    Attributes:
        reference_downscale_factor: The inferred downscale factor of reference videos.
            This is computed from the first batch and cached for metadata export.
    """

    config: VideoToVideoConfig
    reference_downscale_factor: int | None

    def __init__(self, config: VideoToVideoConfig):
        """Initialize the strategy with its configuration.

        Args:
            config: Video-to-video configuration
        """
        super().__init__(config)
        self.reference_downscale_factor = None  # Will be inferred from the first batch

    def get_data_sources(self) -> dict[str, str]:
        """IC-LoRA training requires latents, conditions, and reference latents."""
        return {
            "latents": "latents",
            "conditions": "conditions",
            self.config.reference_latents_dir: "ref_latents",
        }

    def prepare_training_inputs(  # noqa: PLR0915
        self,
        batch: dict[str, Any],
        timestep_sampler: TimestepSampler,
    ) -> ModelInputs:
        """Prepare inputs for IC-LoRA training with reference videos."""
        # Get pre-encoded latents - the dataset provides a uniform non-patchified format [B, C, F, H, W]
        latents = batch["latents"]
        target_latents = latents["latents"]
        ref_latents = batch["ref_latents"]["latents"]

        # Get dimensions
        num_frames = latents["num_frames"][0].item()
        height = latents["height"][0].item()
        width = latents["width"][0].item()

        ref_latents_info = batch["ref_latents"]
        ref_frames = ref_latents_info["num_frames"][0].item()
        ref_height = ref_latents_info["height"][0].item()
        ref_width = ref_latents_info["width"][0].item()

        # Infer the reference downscale factor from dimension ratios.
        # This allows training with downscaled reference videos for efficiency.
        reference_downscale_factor = self._infer_reference_downscale_factor(
            target_height=height,
            target_width=width,
            ref_height=ref_height,
            ref_width=ref_width,
        )

        # Cache the scale factor for metadata export (only on the first batch)
        if self.reference_downscale_factor is None:
            self.reference_downscale_factor = reference_downscale_factor
        elif self.reference_downscale_factor != reference_downscale_factor:
            raise ValueError(
                f"Inconsistent reference downscale factor across batches. "
                f"First batch had factor={self.reference_downscale_factor}, "
                f"but current batch has factor={reference_downscale_factor}. "
                f"All training samples must use the same reference/target resolution ratio."
            )

        # Patchify latents: [B, C, F, H, W] -> [B, seq_len, C]
        target_latents = self._video_patchifier.patchify(target_latents)
        ref_latents = self._video_patchifier.patchify(ref_latents)

        # Handle FPS
        fps = latents.get("fps", None)
        if fps is not None and not torch.all(fps == fps[0]):
            logger.warning(
                f"Different FPS values found in the batch. Found: {fps.tolist()}, using the first one: {fps[0].item()}"
            )
        fps = fps[0].item() if fps is not None else DEFAULT_FPS

        # Get text embeddings (already processed by embedding connectors in the trainer).
        # Video-to-video uses only video embeddings.
        conditions = batch["conditions"]
        prompt_embeds = conditions["video_prompt_embeds"]
        prompt_attention_mask = conditions["prompt_attention_mask"]

        batch_size = target_latents.shape[0]
        ref_seq_len = ref_latents.shape[1]
        target_seq_len = target_latents.shape[1]
        device = target_latents.device
        dtype = target_latents.dtype

        # Create the conditioning mask.
        # Reference tokens are always conditioning (timestep=0).
        ref_conditioning_mask = torch.ones(batch_size, ref_seq_len, dtype=torch.bool, device=device)

        # Target tokens: check for first-frame conditioning
        target_conditioning_mask = self._create_first_frame_conditioning_mask(
            batch_size=batch_size,
            sequence_length=target_seq_len,
            height=height,
            width=width,
            device=device,
            first_frame_conditioning_p=self.config.first_frame_conditioning_p,
        )

        # Combined conditioning mask
        conditioning_mask = torch.cat([ref_conditioning_mask, target_conditioning_mask], dim=1)

        # Sample noise and sigmas for the target
        sigmas = timestep_sampler.sample_for(target_latents)
        noise = torch.randn_like(target_latents)
        sigmas_expanded = sigmas.view(-1, 1, 1)

        # Apply noise to the target
        noisy_target = (1 - sigmas_expanded) * target_latents + sigmas_expanded * noise

        # For first-frame conditioning in the target, use clean latents
        target_conditioning_mask_expanded = target_conditioning_mask.unsqueeze(-1)
        noisy_target = torch.where(target_conditioning_mask_expanded, target_latents, noisy_target)

        # Targets for loss computation
        targets = noise - target_latents

        # Concatenate reference (clean) and target (noisy)
        combined_latents = torch.cat([ref_latents, noisy_target], dim=1)

        # Create per-token timesteps
        timesteps = self._create_per_token_timesteps(conditioning_mask, sigmas.squeeze())

        # Generate positions for reference and target separately, then concatenate
        ref_positions = self._get_video_positions(
            num_frames=ref_frames,
            height=ref_height,
            width=ref_width,
            batch_size=batch_size,
            fps=fps,
            device=device,
            dtype=dtype,
        )

        # Scale reference positions to match the target coordinate space.
        # This maps ref positions from (0, ref_H, ref_W) to (0, target_H, target_W).
        # Position tensor shape: [B, 3, seq_len, 2] where dim 1 is (time, height, width).
        if reference_downscale_factor != 1:
            ref_positions = ref_positions.clone()
            ref_positions[:, 1, ...] *= reference_downscale_factor  # height axis
            ref_positions[:, 2, ...] *= reference_downscale_factor  # width axis
            # The time axis (index 0) remains unchanged

        target_positions = self._get_video_positions(
            num_frames=num_frames,
            height=height,
            width=width,
            batch_size=batch_size,
            fps=fps,
            device=device,
            dtype=dtype,
        )

        # Concatenate positions along the sequence dimension
        positions = torch.cat([ref_positions, target_positions], dim=2)

        # Create the video Modality
        video_modality = Modality(
            enabled=True,
            latent=combined_latents,
            sigma=sigmas,
            timesteps=timesteps,
            positions=positions,
            context=prompt_embeds,
            context_mask=prompt_attention_mask,
        )

        # Loss mask: only compute loss on non-conditioning target tokens.
        # Reference tokens: all False (no loss).
        # Target tokens: True where not conditioning.
        ref_loss_mask = torch.zeros(batch_size, ref_seq_len, dtype=torch.bool, device=device)
        target_loss_mask = ~target_conditioning_mask
        video_loss_mask = torch.cat([ref_loss_mask, target_loss_mask], dim=1)

        return ModelInputs(
            video=video_modality,
            audio=None,
            video_targets=targets,
            audio_targets=None,
            video_loss_mask=video_loss_mask,
            audio_loss_mask=None,
            ref_seq_len=ref_seq_len,
        )

    def compute_loss(
        self,
        video_pred: Tensor,
        _audio_pred: Tensor | None,
        inputs: ModelInputs,
    ) -> Tensor:
        """Compute the masked loss only on the target portion."""
        # Extract the target portion of the prediction
        ref_seq_len = inputs.ref_seq_len
        target_pred = video_pred[:, ref_seq_len:, :]

        # Get the target portion of the loss mask
        target_loss_mask = inputs.video_loss_mask[:, ref_seq_len:]

        # Compute the loss
        loss = (target_pred - inputs.video_targets).pow(2)

        # Apply the loss mask
        loss_mask = target_loss_mask.unsqueeze(-1).float()
        loss = loss.mul(loss_mask).div(loss_mask.mean())

        return loss.mean()

    def get_checkpoint_metadata(self) -> dict[str, Any]:
        """Get metadata for checkpoint files."""
        metadata: dict[str, Any] = {}
        # Include reference_downscale_factor (when known) for IC-LoRAs so inference
        # pipelines know the expected scale factor for reference videos.
        if self.reference_downscale_factor is not None:
            metadata["reference_downscale_factor"] = self.reference_downscale_factor
        return metadata

    @staticmethod
    def _infer_reference_downscale_factor(
        target_height: int,
        target_width: int,
        ref_height: int,
        ref_width: int,
    ) -> int:
        """Infer the reference downscale factor from target and reference dimensions."""
        # If the dimensions match, no scaling is needed
        if target_height == ref_height and target_width == ref_width:
            return 1

        # Calculate scale factors for each dimension
        if target_height % ref_height != 0 or target_width % ref_width != 0:
            raise ValueError(
                f"Target dimensions ({target_height}x{target_width}) must be exact multiples "
                f"of reference dimensions ({ref_height}x{ref_width})"
            )

        scale_h = target_height // ref_height
        scale_w = target_width // ref_width

        if scale_h != scale_w:
            raise ValueError(
                f"Reference scale must be uniform. Got height scale {scale_h} and width scale {scale_w}. "
                f"Target: {target_height}x{target_width}, Reference: {ref_height}x{ref_width}"
            )

        if scale_h < 1:
            raise ValueError(
                f"Reference dimensions ({ref_height}x{ref_width}) cannot be larger than "
                f"target dimensions ({target_height}x{target_width})"
            )

        return scale_h
packages/ltx-trainer/src/ltx_trainer/utils.py
ADDED
@@ -0,0 +1,88 @@
import io
from pathlib import Path

import numpy as np
import torch
from PIL import ExifTags, Image, ImageCms, ImageOps
from PIL.Image import Image as PilImage


def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage:
    """
    Open an image file, apply the rotation stored in its metadata (if any),
    and convert it to the sRGB color space, respecting the original color space.
    Args:
        image_path: Path to the image file
    Returns:
        PIL Image in sRGB color space
    """
    exif_colorspace_srgb = 1

    with Image.open(image_path) as img_raw:
        img = ImageOps.exif_transpose(img_raw)

        input_icc_profile = img.info.get("icc_profile")

        # Convert to sRGB if the image has ICC profile metadata
        srgb_profile = ImageCms.createProfile(colorSpace="sRGB")
        if input_icc_profile is not None:
            input_profile = ImageCms.ImageCmsProfile(io.BytesIO(input_icc_profile))
            srgb_img = ImageCms.profileToProfile(img, input_profile, srgb_profile, outputMode="RGB")
        else:
            # Fall back to checking EXIF
            exif_data = img.getexif()
            if exif_data is not None:
                # Assume sRGB if there is no ICC profile and EXIF has no ColorSpace tag
                color_space_value = exif_data.get(ExifTags.Base.ColorSpace.value)
                if color_space_value is not None and color_space_value != exif_colorspace_srgb:
                    raise ValueError(
                        "Image has a colorspace tag in EXIF but it isn't set to sRGB,"
                        " conversion is not supported."
                        f" EXIF ColorSpace tag value is {color_space_value}",
                    )

            srgb_img = img.convert("RGB")

        # Set the sRGB profile in metadata, since the image is now assumed to be in sRGB
        srgb_profile_data = ImageCms.ImageCmsProfile(srgb_profile).tobytes()
        srgb_img.info["icc_profile"] = srgb_profile_data

        return srgb_img


def save_image(image_tensor: torch.Tensor, output_path: Path | str) -> None:
    """Save an image tensor to a file.
    Args:
        image_tensor: Image tensor of shape [C, H, W] or [C, 1, H, W] in range [0, 1] or [0, 255].
            C must be 3 (RGB).
        output_path: Path to save the image (any PIL-supported format, e.g., .png or .jpg)
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Handle the [C, 1, H, W] format (a single frame from a video tensor)
    if image_tensor.ndim == 4:
        # Squeeze the frame dimension: [C, 1, H, W] -> [C, H, W]
        if image_tensor.shape[1] == 1:
            image_tensor = image_tensor.squeeze(1)
        else:
            raise ValueError(f"Expected single-frame tensor with shape [C, 1, H, W], got shape {image_tensor.shape}")

    if image_tensor.ndim != 3:
        raise ValueError(f"Expected 3D tensor [C, H, W], got {image_tensor.ndim}D tensor")

    if image_tensor.shape[0] != 3:
        raise ValueError(f"Expected 3 channels (RGB), got {image_tensor.shape[0]} channels")

    # Normalize to [0, 255] uint8
    if torch.is_floating_point(image_tensor) and image_tensor.max() <= 1.0:
        image_tensor = image_tensor * 255

    # Clamp to the valid uint8 range to prevent overflow
    image_tensor = image_tensor.clamp(0, 255)

    # [C, H, W] -> [H, W, C]
    image_np: np.ndarray = image_tensor.permute(1, 2, 0).to(torch.uint8).cpu().numpy()

    # Save using PIL
    Image.fromarray(image_np).save(output_path)
packages/ltx-trainer/templates/model_card.md
ADDED
---
tags:
- ltx-2
- ltx-video
- text-to-video
- audio-video
pinned: true
language:
- en
license: other
pipeline_tag: text-to-video
library_name: diffusers
---

# {model_name}

This is a fine-tuned version of [`{base_model}`]({base_model_link}) trained on custom data.

## Model Details

- **Base Model:** [`{base_model}`]({base_model_link})
- **Training Type:** {training_type}
- **Training Steps:** {training_steps}
- **Learning Rate:** {learning_rate}
- **Batch Size:** {batch_size}

## Sample Outputs

| | | | |
|:---:|:---:|:---:|:---:|
{sample_grid}

## Usage

This model is designed to be used with the LTX-2 (Lightricks Audio-Video) pipeline.

### 🔌 Using Trained LoRAs in ComfyUI

To use the trained LoRA in ComfyUI, follow these steps:

1. Copy your trained LoRA checkpoint (the `.safetensors` file) to the `models/loras` folder in your ComfyUI installation.
2. In your ComfyUI workflow:
   - Add the "Load LoRA" node to choose your LoRA file
   - Connect it to the "Load Checkpoint" node to apply the LoRA to the base model

You can find reference Text-to-Video (T2V) and Image-to-Video (I2V) workflows in the
official [LTX-2 repository](https://github.com/Lightricks/LTX-2).

### Example Prompts

{validation_prompts}

## License

This model inherits the license of the base model ([`{base_model}`]({base_model_link})).

## Acknowledgments

- Base model: [Lightricks](https://huggingface.co/Lightricks/LTX-2)
- Trainer: [LTX-2](https://github.com/Lightricks/LTX-2)