jiuhai commited on
Commit
a3c20e1
·
verified ·
1 Parent(s): 4ab332d

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. packages/ltx-core/pyproject.toml +55 -0
  2. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  3. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/attention.cpython-312.pyc +0 -0
  4. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/feed_forward.cpython-312.pyc +0 -0
  5. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/gelu_approx.cpython-312.pyc +0 -0
  6. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/modality.cpython-312.pyc +0 -0
  7. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model.cpython-312.pyc +0 -0
  8. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model_configurator.cpython-312.pyc +0 -0
  9. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/rope.cpython-312.pyc +0 -0
  10. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/text_projection.cpython-312.pyc +0 -0
  11. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/timestep_embedding.cpython-312.pyc +0 -0
  12. packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer.cpython-312.pyc +0 -0
  13. packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py +202 -0
  14. packages/ltx-trainer/configs/accelerate/ddp.yaml +16 -0
  15. packages/ltx-trainer/configs/accelerate/ddp_compile.yaml +21 -0
  16. packages/ltx-trainer/configs/accelerate/fsdp.yaml +29 -0
  17. packages/ltx-trainer/configs/accelerate/fsdp_compile.yaml +34 -0
  18. packages/ltx-trainer/configs/ltx2_av_lora.yaml +313 -0
  19. packages/ltx-trainer/configs/ltx2_av_lora_low_vram.yaml +325 -0
  20. packages/ltx-trainer/configs/ltx2_v2v_ic_lora.yaml +329 -0
  21. packages/ltx-trainer/docs/configuration-reference.md +372 -0
  22. packages/ltx-trainer/docs/custom-training-strategies.md +510 -0
  23. packages/ltx-trainer/docs/dataset-preparation.md +342 -0
  24. packages/ltx-trainer/docs/quick-start.md +130 -0
  25. packages/ltx-trainer/docs/training-guide.md +203 -0
  26. packages/ltx-trainer/docs/training-modes.md +277 -0
  27. packages/ltx-trainer/docs/troubleshooting.md +300 -0
  28. packages/ltx-trainer/docs/utility-scripts.md +274 -0
  29. packages/ltx-trainer/scripts/caption_videos.py +486 -0
  30. packages/ltx-trainer/scripts/compute_reference.py +288 -0
  31. packages/ltx-trainer/scripts/decode_latents.py +369 -0
  32. packages/ltx-trainer/scripts/process_captions.py +435 -0
  33. packages/ltx-trainer/scripts/process_dataset.py +317 -0
  34. packages/ltx-trainer/scripts/process_videos.py +1039 -0
  35. packages/ltx-trainer/scripts/split_scenes.py +417 -0
  36. packages/ltx-trainer/scripts/train.py +64 -0
  37. packages/ltx-trainer/src/ltx_trainer/__pycache__/__init__.cpython-312.pyc +0 -0
  38. packages/ltx-trainer/src/ltx_trainer/__pycache__/model_loader.cpython-312.pyc +0 -0
  39. packages/ltx-trainer/src/ltx_trainer/captioning.py +401 -0
  40. packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py +85 -0
  41. packages/ltx-trainer/src/ltx_trainer/gpu_utils.py +90 -0
  42. packages/ltx-trainer/src/ltx_trainer/progress.py +236 -0
  43. packages/ltx-trainer/src/ltx_trainer/quantization.py +195 -0
  44. packages/ltx-trainer/src/ltx_trainer/trainer.py +1000 -0
  45. packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py +58 -0
  46. packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py +262 -0
  47. packages/ltx-trainer/src/ltx_trainer/training_strategies/text_to_video.py +291 -0
  48. packages/ltx-trainer/src/ltx_trainer/training_strategies/video_to_video.py +303 -0
  49. packages/ltx-trainer/src/ltx_trainer/utils.py +88 -0
  50. packages/ltx-trainer/templates/model_card.md +59 -0
packages/ltx-core/pyproject.toml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "ltx-core"
3
+ version = "1.0.0"
4
+ description = "Core implementation of Lightricks' LTX-2 model"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "torch~=2.7",
9
+ "torchaudio",
10
+ "einops",
11
+ "numpy",
12
+ "transformers>=4.52",
13
+ "safetensors",
14
+ "accelerate",
15
+ "scipy>=1.14",
16
+ ]
17
+
18
+ [project.optional-dependencies]
19
+ xformers = ["xformers"]
20
+ fp8-trtllm = [
21
+ "tensorrt-llm==1.0.0",
22
+ "onnx>=1.16.0,<1.20.0",
23
+ "openmpi",
24
+ ]
25
+
26
+ [tool.uv]
27
+ conflicts = [
28
+ [
29
+ { extra = "xformers" },
30
+ { extra = "fp8-trtllm" },
31
+ ],
32
+ ]
33
+
34
+ [tool.uv.sources]
35
+ xformers = { index = "pytorch" }
36
+ tensorrt-llm = { index = "nvidia" }
37
+
38
+ [[tool.uv.index]]
39
+ name = "pytorch"
40
+ url = "https://download.pytorch.org/whl/cu129"
41
+ explicit = true
42
+
43
+ [[tool.uv.index]]
44
+ name = "nvidia"
45
+ url = "https://pypi.nvidia.com/"
46
+ explicit = true
47
+
48
+ [build-system]
49
+ requires = ["uv_build>=0.9.8,<0.10.0"]
50
+ build-backend = "uv_build"
51
+
52
+ [dependency-groups]
53
+ dev = [
54
+ "scikit-image>=0.25.2",
55
+ ]
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (624 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/attention.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/feed_forward.cpython-312.pyc ADDED
Binary file (1.52 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/gelu_approx.cpython-312.pyc ADDED
Binary file (1.28 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/modality.cpython-312.pyc ADDED
Binary file (2.29 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model.cpython-312.pyc ADDED
Binary file (20.6 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/model_configurator.cpython-312.pyc ADDED
Binary file (9.1 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/rope.cpython-312.pyc ADDED
Binary file (10.6 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/text_projection.cpython-312.pyc ADDED
Binary file (2.75 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/timestep_embedding.cpython-312.pyc ADDED
Binary file (7.2 kB). View file
 
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (18 kB). View file
 
packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor
6
+
7
+ from ltx_core.loader.module_ops import ModuleOps
8
+ from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
9
+ from ltx_core.utils import find_matching_file
10
+
11
+
12
+ class GemmaTextEncoder(torch.nn.Module):
13
+ """Pure Gemma text encoder — runs the LLM and returns raw hidden states.
14
+ Prompt enhancement (generate) is also supported since the full
15
+ Gemma3ForConditionalGeneration model (including lm_head) is loaded.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ model: Gemma3ForConditionalGeneration | None = None,
21
+ tokenizer: LTXVGemmaTokenizer | None = None,
22
+ processor: Gemma3Processor | None = None,
23
+ dtype: torch.dtype = torch.bfloat16,
24
+ ):
25
+ super().__init__()
26
+ self.model = model
27
+ self.tokenizer = tokenizer
28
+ self.processor = processor
29
+ self._dtype = dtype
30
+
31
+ def encode(
32
+ self,
33
+ text: str,
34
+ padding_side: str = "left", # noqa: ARG002
35
+ ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
36
+ """Run Gemma LLM and return raw hidden states + attention mask.
37
+ Calls the inner model (self.model.model) to skip lm_head logits computation (~500 MiB saving).
38
+ Returns:
39
+ (hidden_states, attention_mask) where hidden_states is a tuple of per-layer tensors.
40
+ """
41
+ token_pairs = self.tokenizer.tokenize_with_weights(text)["gemma"]
42
+ input_ids = torch.tensor([[t[0] for t in token_pairs]], device=self.model.device)
43
+ attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=self.model.device)
44
+ outputs = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
45
+ hidden_states = outputs.hidden_states
46
+ del outputs
47
+ return hidden_states, attention_mask
48
+
49
+ # --- Prompt enhancement methods ---
50
+
51
+ def _enhance(
52
+ self,
53
+ messages: list[dict[str, str]],
54
+ image: torch.Tensor | None = None,
55
+ max_new_tokens: int = 512,
56
+ seed: int = 10,
57
+ ) -> str:
58
+ text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
59
+
60
+ model_inputs = self.processor(
61
+ text=text,
62
+ images=image,
63
+ return_tensors="pt",
64
+ ).to(self.model.device)
65
+ pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0
66
+ model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id)
67
+
68
+ with torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]):
69
+ torch.manual_seed(seed)
70
+ outputs = self.model.generate(
71
+ **model_inputs,
72
+ max_new_tokens=max_new_tokens,
73
+ do_sample=True,
74
+ temperature=0.7,
75
+ )
76
+ generated_ids = outputs[0][len(model_inputs.input_ids[0]) :]
77
+ enhanced_prompt = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)
78
+
79
+ return enhanced_prompt
80
+
81
+ def enhance_t2v(
82
+ self,
83
+ prompt: str,
84
+ max_new_tokens: int = 512,
85
+ system_prompt: str | None = None,
86
+ seed: int = 10,
87
+ ) -> str:
88
+ """Enhance a text prompt for T2V generation."""
89
+ system_prompt = system_prompt or self.default_gemma_t2v_system_prompt
90
+
91
+ messages = [
92
+ {"role": "system", "content": system_prompt},
93
+ {"role": "user", "content": f"user prompt: {prompt}"},
94
+ ]
95
+
96
+ return self._enhance(messages, max_new_tokens=max_new_tokens, seed=seed)
97
+
98
+ def enhance_i2v(
99
+ self,
100
+ prompt: str,
101
+ image: torch.Tensor,
102
+ max_new_tokens: int = 512,
103
+ system_prompt: str | None = None,
104
+ seed: int = 10,
105
+ ) -> str:
106
+ """Enhance a text prompt for I2V generation using a reference image."""
107
+ system_prompt = system_prompt or self.default_gemma_i2v_system_prompt
108
+ messages = [
109
+ {"role": "system", "content": system_prompt},
110
+ {
111
+ "role": "user",
112
+ "content": [
113
+ {"type": "image"},
114
+ {"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
115
+ ],
116
+ },
117
+ ]
118
+ return self._enhance(messages, image=image, max_new_tokens=max_new_tokens, seed=seed)
119
+
120
+ @functools.cached_property
121
+ def default_gemma_i2v_system_prompt(self) -> str:
122
+ return _load_system_prompt("gemma_i2v_system_prompt.txt")
123
+
124
+ @functools.cached_property
125
+ def default_gemma_t2v_system_prompt(self) -> str:
126
+ return _load_system_prompt("gemma_t2v_system_prompt.txt")
127
+
128
+
129
+ # --- Standalone utility functions ---
130
+
131
+
132
+ @functools.lru_cache(maxsize=2)
133
+ def _load_system_prompt(prompt_name: str) -> str:
134
+ with open(Path(__file__).parent / "prompts" / f"{prompt_name}", "r") as f:
135
+ return f.read()
136
+
137
+
138
+ def _cat_with_padding(
139
+ tensor: torch.Tensor,
140
+ padding_length: int,
141
+ value: int | float,
142
+ ) -> torch.Tensor:
143
+ """Concatenate a tensor with a padding tensor of the given value."""
144
+ return torch.cat(
145
+ [
146
+ tensor,
147
+ torch.full(
148
+ (1, padding_length),
149
+ value,
150
+ dtype=tensor.dtype,
151
+ device=tensor.device,
152
+ ),
153
+ ],
154
+ dim=1,
155
+ )
156
+
157
+
158
+ def _pad_inputs_for_attention_alignment(
159
+ model_inputs: dict[str, torch.Tensor],
160
+ pad_token_id: int = 0,
161
+ alignment: int = 8,
162
+ ) -> dict[str, torch.Tensor]:
163
+ """Pad sequence length to multiple of alignment for Flash Attention compatibility."""
164
+ seq_len = model_inputs.input_ids.shape[1]
165
+ padded_len = ((seq_len + alignment - 1) // alignment) * alignment
166
+ padding_length = padded_len - seq_len
167
+
168
+ if padding_length > 0:
169
+ model_inputs["input_ids"] = _cat_with_padding(model_inputs.input_ids, padding_length, pad_token_id)
170
+ model_inputs["attention_mask"] = _cat_with_padding(model_inputs.attention_mask, padding_length, 0)
171
+ if "token_type_ids" in model_inputs and model_inputs["token_type_ids"] is not None:
172
+ model_inputs["token_type_ids"] = _cat_with_padding(model_inputs["token_type_ids"], padding_length, 0)
173
+
174
+ return model_inputs
175
+
176
+
177
+ def module_ops_from_gemma_root(gemma_root: str) -> tuple[ModuleOps, ...]:
178
+ tokenizer_root = str(find_matching_file(gemma_root, "tokenizer.model").parent)
179
+ processor_root = str(find_matching_file(gemma_root, "preprocessor_config.json").parent)
180
+
181
+ def load_tokenizer(module: GemmaTextEncoder) -> GemmaTextEncoder:
182
+ module.tokenizer = LTXVGemmaTokenizer(tokenizer_root, 1024)
183
+ return module
184
+
185
+ def load_processor(module: GemmaTextEncoder) -> GemmaTextEncoder:
186
+ image_processor = AutoImageProcessor.from_pretrained(processor_root, local_files_only=True)
187
+ if not module.tokenizer:
188
+ raise ValueError("Tokenizer model operation must be performed before processor model operation")
189
+ module.processor = Gemma3Processor(image_processor=image_processor, tokenizer=module.tokenizer.tokenizer)
190
+ return module
191
+
192
+ tokenizer_load_ops = ModuleOps(
193
+ "TokenizerLoad",
194
+ matcher=lambda module: isinstance(module, GemmaTextEncoder) and module.tokenizer is None,
195
+ mutator=load_tokenizer,
196
+ )
197
+ processor_load_ops = ModuleOps(
198
+ "ProcessorLoad",
199
+ matcher=lambda module: isinstance(module, GemmaTextEncoder) and module.processor is None,
200
+ mutator=load_processor,
201
+ )
202
+ return (tokenizer_load_ops, processor_load_ops)
packages/ltx-trainer/configs/accelerate/ddp.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: bf16
9
+ num_machines: 1
10
+ num_processes: 4
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
packages/ltx-trainer/configs/accelerate/ddp_compile.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ dynamo_config:
3
+ dynamo_backend: INDUCTOR
4
+ dynamo_mode: default
5
+ dynamo_use_fullgraph: false
6
+ dynamo_use_dynamic: true
7
+ debug: false
8
+ distributed_type: MULTI_GPU
9
+ downcast_bf16: 'no'
10
+ enable_cpu_affinity: false
11
+ machine_rank: 0
12
+ main_training_function: main
13
+ mixed_precision: bf16
14
+ num_machines: 1
15
+ num_processes: 4
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: [ ]
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
packages/ltx-trainer/configs/accelerate/fsdp.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: FSDP
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ fsdp_config:
7
+ fsdp_activation_checkpointing: false
8
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
9
+ fsdp_backward_prefetch: BACKWARD_PRE
10
+ fsdp_cpu_ram_efficient_loading: true
11
+ fsdp_forward_prefetch: false
12
+ fsdp_offload_params: false
13
+ fsdp_reshard_after_forward: FULL_SHARD
14
+ fsdp_state_dict_type: SHARDED_STATE_DICT
15
+ fsdp_sync_module_states: true
16
+ fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
17
+ fsdp_use_orig_params: true
18
+ fsdp_version: 1
19
+ machine_rank: 0
20
+ main_training_function: main
21
+ mixed_precision: bf16
22
+ num_machines: 1
23
+ num_processes: 4
24
+ rdzv_backend: static
25
+ same_network: true
26
+ tpu_env: []
27
+ tpu_use_cluster: false
28
+ tpu_use_sudo: false
29
+ use_cpu: false
packages/ltx-trainer/configs/accelerate/fsdp_compile.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: FSDP
4
+ downcast_bf16: 'no'
5
+ dynamo_config:
6
+ dynamo_backend: INDUCTOR
7
+ dynamo_mode: default
8
+ dynamo_use_fullgraph: false
9
+ dynamo_use_dynamic: true
10
+ enable_cpu_affinity: false
11
+ fsdp_config:
12
+ fsdp_activation_checkpointing: false
13
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
14
+ fsdp_backward_prefetch: BACKWARD_PRE
15
+ fsdp_cpu_ram_efficient_loading: true
16
+ fsdp_forward_prefetch: false
17
+ fsdp_offload_params: false
18
+ fsdp_reshard_after_forward: FULL_SHARD
19
+ fsdp_state_dict_type: SHARDED_STATE_DICT
20
+ fsdp_sync_module_states: true
21
+ fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
22
+ fsdp_use_orig_params: true
23
+ fsdp_version: 1
24
+ machine_rank: 0
25
+ main_training_function: main
26
+ mixed_precision: bf16
27
+ num_machines: 1
28
+ num_processes: 4
29
+ rdzv_backend: static
30
+ same_network: true
31
+ tpu_env: []
32
+ tpu_use_cluster: false
33
+ tpu_use_sudo: false
34
+ use_cpu: false
packages/ltx-trainer/configs/ltx2_av_lora.yaml ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # LTX-2 Audio-Video LoRA Training Configuration
3
+ # =============================================================================
4
+ #
5
+ # This configuration is for training LoRA adapters on the LTX-2 model for
6
+ # text-to-video generation. It supports both video-only and joint audio-video
7
+ # training modes.
8
+ #
9
+ # Use this configuration when you want to:
10
+ # - Fine-tune LTX-2 on your own video dataset
11
+ # - Train with or without audio generation
12
+ # - Create custom video generation styles or audiovisual concepts
13
+ #
14
+ # Dataset structure for text-to-video training:
15
+ # preprocessed_data_root/
16
+ # ├── latents/ # Video latents (VAE-encoded videos)
17
+ # ├── conditions/ # Text embeddings for each video
18
+ # └── audio_latents/ # Audio latents (only if with_audio: true)
19
+ #
20
+ # =============================================================================
21
+
22
+ # -----------------------------------------------------------------------------
23
+ # Model Configuration
24
+ # -----------------------------------------------------------------------------
25
+ # Specifies the base model to fine-tune and the training mode.
26
+ model:
27
+ # Path to the LTX-2 model checkpoint (.safetensors file)
28
+ # This should be a local path to your downloaded model
29
+ model_path: "path/to/ltx-2-model.safetensors"
30
+
31
+ # Path to the text encoder model directory
32
+ # For LTX-2, this is typically the Gemma-based text encoder
33
+ text_encoder_path: "path/to/gemma-text-encoder"
34
+
35
+ # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning
36
+ # LoRA is recommended for most use cases (faster, less memory, prevents overfitting)
37
+ training_mode: "lora"
38
+
39
+ # Optional: Path to resume training from a checkpoint
40
+ # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint)
41
+ load_checkpoint: null
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # LoRA Configuration
45
+ # -----------------------------------------------------------------------------
46
+ # Controls the Low-Rank Adaptation parameters for efficient fine-tuning.
47
+ lora:
48
+ # Rank of the LoRA matrices (higher = more capacity but more parameters)
49
+ # Typical values: 8, 16, 32, 64. Start with 32 for general fine-tuning.
50
+ rank: 32
51
+
52
+ # Alpha scaling factor (usually set equal to rank)
53
+ # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0
54
+ alpha: 32
55
+
56
+ # Dropout probability for LoRA layers (0.0 = no dropout)
57
+ # Can help with regularization if overfitting occurs
58
+ dropout: 0.0
59
+
60
+ # Which transformer modules to apply LoRA to
61
+ # The LTX-2 transformer has separate attention and FFN blocks for video and audio:
62
+ #
63
+ # VIDEO MODULES:
64
+ # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention)
65
+ # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text)
66
+ # - ff.net.0.proj, ff.net.2 (video feed-forward)
67
+ #
68
+ # AUDIO MODULES:
69
+ # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention)
70
+ # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text)
71
+ # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward)
72
+ #
73
+ # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction):
74
+ # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0
75
+ # (Q from video, K/V from audio - allows video to attend to audio features)
76
+ # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0
77
+ # (Q from audio, K/V from video - allows audio to attend to video features)
78
+ #
79
+ # Using short patterns like "to_k" matches ALL attention modules (video, audio, and cross-modal).
80
+ # For audio-video training, this is the recommended approach.
81
+ target_modules:
82
+ # Attention layers (matches both video and audio branches)
83
+ - "to_k"
84
+ - "to_q"
85
+ - "to_v"
86
+ - "to_out.0"
87
+ # Uncomment below to also train feed-forward layers (can increase the LoRA's capacity):
88
+ # - "ff.net.0.proj"
89
+ # - "ff.net.2"
90
+ # - "audio_ff.net.0.proj"
91
+ # - "audio_ff.net.2"
92
+
93
+ # -----------------------------------------------------------------------------
94
+ # Training Strategy Configuration
95
+ # -----------------------------------------------------------------------------
96
+ # Defines the text-to-video training approach.
97
+ training_strategy:
98
+ # Strategy name: "text_to_video" for standard text-to-video training
99
+ name: "text_to_video"
100
+
101
+ # Probability of conditioning on the first frame during training
102
+ # Higher values train the model to perform better in image-to-video (I2V) mode,
103
+ # where a clean first frame is provided and the model generates the rest of the video
104
+ # Increase this value to train the model to perform better in image-to-video (I2V) mode
105
+ first_frame_conditioning_p: 0.5
106
+
107
+ # Enable joint audio-video training
108
+ # Set to true if your dataset includes audio and you want to train the audio branch
109
+ with_audio: true
110
+
111
+ # Directory name (within preprocessed_data_root) containing audio latents
112
+ # Only used when with_audio is true
113
+ audio_latents_dir: "audio_latents"
114
+
115
+ # -----------------------------------------------------------------------------
116
+ # Optimization Configuration
117
+ # -----------------------------------------------------------------------------
118
+ # Controls the training optimization parameters.
119
+ optimization:
120
+ # Learning rate for the optimizer
121
+ # Typical range for LoRA: 1e-5 to 1e-4
122
+ learning_rate: 1e-4
123
+
124
+ # Total number of training steps
125
+ steps: 2000
126
+
127
+ # Batch size per GPU
128
+ # Reduce if running out of memory
129
+ batch_size: 1
130
+
131
+ # Number of gradient accumulation steps
132
+ # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus
133
+ gradient_accumulation_steps: 1
134
+
135
+ # Maximum gradient norm for clipping (helps training stability)
136
+ max_grad_norm: 1.0
137
+
138
+ # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient)
139
+ optimizer_type: "adamw"
140
+
141
+ # Learning rate scheduler type
142
+ # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial"
143
+ scheduler_type: "linear"
144
+
145
+ # Additional scheduler parameters (depends on scheduler_type)
146
+ scheduler_params: { }
147
+
148
+ # Enable gradient checkpointing to reduce memory usage
149
+ # Recommended for training with limited GPU memory
150
+ enable_gradient_checkpointing: true
151
+
152
+ # -----------------------------------------------------------------------------
153
+ # Acceleration Configuration
154
+ # -----------------------------------------------------------------------------
155
+ # Hardware acceleration and memory optimization settings.
156
+ acceleration:
157
+ # Mixed precision training mode
158
+ # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended)
159
+ mixed_precision_mode: "bf16"
160
+
161
+ # Model quantization for reduced memory usage
162
+ # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto"
163
+ quantization: null
164
+
165
+ # Load text encoder in 8-bit precision to save memory
166
+ # Useful when GPU memory is limited
167
+ load_text_encoder_in_8bit: false
168
+
169
+ # -----------------------------------------------------------------------------
170
+ # Data Configuration
171
+ # -----------------------------------------------------------------------------
172
+ # Specifies the training data location and loading parameters.
173
+ data:
174
+ # Root directory containing preprocessed training data
175
+ # Should contain: latents/, conditions/, and optionally audio_latents/
176
+ preprocessed_data_root: "/path/to/preprocessed/data"
177
+
178
+ # Number of worker processes for data loading
179
+ # Used for parallel data loading to speed up data loading
180
+ num_dataloader_workers: 2
181
+
182
+ # -----------------------------------------------------------------------------
183
+ # Validation Configuration
184
+ # -----------------------------------------------------------------------------
185
+ # Controls validation video generation during training.
186
+ # NOTE: Validation sampling use simplified inference pipelines and prioritizes speed over
187
+ # maximum quality. For production-quality inference, use `packages/ltx-pipelines`.
188
+ validation:
189
+ # Text prompts for validation video generation
190
+ # Provide prompts representative of your training data
191
+ # LTX-2 prefers longer, detailed prompts that describe both visual content and audio
192
+ prompts:
193
+ - "A woman with long brown hair sits at a wooden desk in a cozy home office, typing on a laptop while occasionally glancing at notes beside her. Soft natural light streams through a large window, casting warm shadows across the room. She pauses to take a sip from a ceramic mug, then continues working with focused concentration. The audio captures the gentle clicking of keyboard keys, the soft rustle of papers, and ambient room tone with occasional distant bird chirps from outside."
194
+ - "A chef in a white uniform stands in a professional kitchen, carefully plating a gourmet dish with precise movements. Steam rises from freshly cooked vegetables as he arranges them with tweezers. The stainless steel surfaces gleam under bright overhead lights, and various pots simmer on the stove behind him. The audio features the sizzling of pans, the clinking of utensils against plates, and the ambient hum of kitchen ventilation."
195
+
196
+ # Negative prompt to avoid unwanted artifacts
197
+ negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
198
+
199
+ # Optional: First frame images for image-to-video validation
200
+ # If provided, must have one image per prompt
201
+ images: null
202
+
203
+ # Output video dimensions [width, height, frames]
204
+ # Width and height must be divisible by 32
205
+ # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...)
206
+ video_dims: [ 576, 576, 89 ]
207
+
208
+ # Frame rate for generated videos
209
+ frame_rate: 25.0
210
+
211
+ # Random seed for reproducible validation outputs
212
+ seed: 42
213
+
214
+ # Number of denoising steps for validation inference
215
+ # Higher values = better quality but slower generation
216
+ inference_steps: 30
217
+
218
+ # Generate validation videos every N training steps
219
+ # Set to null to disable validation during training
220
+ interval: 100
221
+
222
+ # Number of videos to generate per prompt
223
+ videos_per_prompt: 1
224
+
225
+ # Classifier-free guidance scale
226
+ # Higher values = stronger adherence to prompt but may introduce artifacts
227
+ guidance_scale: 4.0
228
+
229
+ # STG (Spatio-Temporal Guidance) parameters for improved video quality
230
+ # STG is combined with CFG for better temporal coherence
231
+ stg_scale: 1.0 # Recommended: 1.0 (0.0 disables STG)
232
+ stg_blocks: [29] # Recommended: single block 29
233
+ stg_mode: "stg_av" # "stg_av" perturbs both audio and video, "stg_v" video only
234
+
235
+ # Whether to generate audio in validation samples
236
+ # Independent of training_strategy.with_audio - you can generate audio
237
+ # in validation even when not training the audio branch
238
+ generate_audio: true
239
+
240
+ # Skip validation at the beginning of training (step 0)
241
+ skip_initial_validation: false
242
+
243
+ # -----------------------------------------------------------------------------
244
+ # Checkpoint Configuration
245
+ # -----------------------------------------------------------------------------
246
+ # Controls model checkpoint saving during training.
247
+ checkpoints:
248
+ # Save a checkpoint every N steps
249
+ # Set to null to disable intermediate checkpoints
250
+ interval: 250
251
+
252
+ # Number of most recent checkpoints to keep
253
+ # Set to -1 to keep all checkpoints
254
+ keep_last_n: -1
255
+
256
+ # Precision to use when saving checkpoint weights
257
+ # Options: "bfloat16" (default, smaller files) or "float32" (full precision)
258
+ precision: "bfloat16"
259
+
260
+ # -----------------------------------------------------------------------------
261
+ # Flow Matching Configuration
262
+ # -----------------------------------------------------------------------------
263
+ # Parameters for the flow matching training objective.
264
+ flow_matching:
265
+ # Timestep sampling mode
266
+ # "shifted_logit_normal" is recommended for LTX-2 models
267
+ timestep_sampling_mode: "shifted_logit_normal"
268
+
269
+ # Additional parameters for timestep sampling
270
+ timestep_sampling_params: { }
271
+
272
+ # -----------------------------------------------------------------------------
273
+ # Hugging Face Hub Configuration
274
+ # -----------------------------------------------------------------------------
275
+ # Settings for uploading trained models to the Hugging Face Hub.
276
+ hub:
277
+ # Whether to push the trained model to the Hub
278
+ push_to_hub: false
279
+
280
+ # Repository ID on Hugging Face Hub (e.g., "username/my-lora-model")
281
+ # Required if push_to_hub is true
282
+ hub_model_id: null
283
+
284
+ # -----------------------------------------------------------------------------
285
+ # Weights & Biases Configuration
286
+ # -----------------------------------------------------------------------------
287
+ # Settings for experiment tracking with W&B.
288
+ wandb:
289
+ # Enable W&B logging
290
+ enabled: false
291
+
292
+ # W&B project name
293
+ project: "ltx-2-trainer"
294
+
295
+ # W&B username or team (null uses default account)
296
+ entity: null
297
+
298
+ # Tags to help organize runs
299
+ tags: [ "ltx2", "lora" ]
300
+
301
+ # Log validation videos to W&B
302
+ log_validation_videos: true
303
+
304
+ # -----------------------------------------------------------------------------
305
+ # General Configuration
306
+ # -----------------------------------------------------------------------------
307
+ # Global settings for the training run.
308
+
309
+ # Random seed for reproducibility
310
+ seed: 42
311
+
312
+ # Directory to save outputs (checkpoints, validation videos, logs)
313
+ output_dir: "outputs/ltx2_av_lora"
packages/ltx-trainer/configs/ltx2_av_lora_low_vram.yaml ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # LTX-2 Audio-Video LoRA Training Configuration (Low VRAM)
3
+ # =============================================================================
4
+ #
5
+ # This is a memory-optimized variant of the standard audio-video LoRA config.
6
+ # It uses 8-bit optimizer, int8 quantization, and reduced LoRA rank to minimize
7
+ # GPU memory usage while maintaining good training quality.
8
+ #
9
+ # Memory optimizations applied:
10
+ # - 8-bit AdamW optimizer (reduces optimizer state memory by ~75%)
11
+ # - INT8 model quantization (reduces model memory by ~50%)
12
+ # - Lower LoRA rank (16 vs 32, reduces trainable parameters)
13
+ # - Gradient checkpointing enabled
14
+ #
15
+ # Recommended for GPUs with 32GB VRAM (e.g., RTX 5090).
16
+ #
17
+ # Use this configuration when you want to:
18
+ # - Fine-tune LTX-2 on your own video dataset
19
+ # - Train with or without audio generation
20
+ # - Create custom video generation styles or audiovisual concepts
21
+ #
22
+ # Dataset structure for text-to-video training:
23
+ # preprocessed_data_root/
24
+ # ├── latents/ # Video latents (VAE-encoded videos)
25
+ # ├── conditions/ # Text embeddings for each video
26
+ # └── audio_latents/ # Audio latents (only if with_audio: true)
27
+ #
28
+ # =============================================================================
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # Model Configuration
32
+ # -----------------------------------------------------------------------------
33
+ # Specifies the base model to fine-tune and the training mode.
34
+ model:
35
+ # Path to the LTX-2 model checkpoint (.safetensors file)
36
+ # This should be a local path to your downloaded model
37
+ model_path: "path/to/ltx-2-model.safetensors"
38
+
39
+ # Path to the text encoder model directory
40
+ # For LTX-2, this is typically the Gemma-based text encoder
41
+ text_encoder_path: "path/to/gemma-text-encoder"
42
+
43
+ # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning
44
+ # LoRA is recommended for most use cases (faster, less memory, prevents overfitting)
45
+ training_mode: "lora"
46
+
47
+ # Optional: Path to resume training from a checkpoint
48
+ # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint)
49
+ load_checkpoint: null
50
+
51
+ # -----------------------------------------------------------------------------
52
+ # LoRA Configuration
53
+ # -----------------------------------------------------------------------------
54
+ # Controls the Low-Rank Adaptation parameters for efficient fine-tuning.
55
+ # Using a lower rank (16) to reduce trainable parameters and memory usage.
56
+ # This still provides good capacity for many fine-tuning tasks.
57
+ lora:
58
+ # Rank of the LoRA matrices (higher = more capacity but more parameters)
59
+ # Typical values: 8, 16, 32, 64. Using 16 for low VRAM configuration.
60
+ rank: 16
61
+
62
+ # Alpha scaling factor (usually set equal to rank)
63
+ # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0
64
+ alpha: 16
65
+
66
+ # Dropout probability for LoRA layers (0.0 = no dropout)
67
+ # Can help with regularization if overfitting occurs
68
+ dropout: 0.0
69
+
70
+ # Which transformer modules to apply LoRA to
71
+ # The LTX-2 transformer has separate attention and FFN blocks for video and audio:
72
+ #
73
+ # VIDEO MODULES:
74
+ # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention)
75
+ # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text)
76
+ # - ff.net.0.proj, ff.net.2 (video feed-forward)
77
+ #
78
+ # AUDIO MODULES:
79
+ # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention)
80
+ # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text)
81
+ # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward)
82
+ #
83
+ # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction):
84
+ # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0
85
+ # (Q from video, K/V from audio - allows video to attend to audio features)
86
+ # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0
87
+ # (Q from audio, K/V from video - allows audio to attend to video features)
88
+ #
89
+ # Using short patterns like "to_k" matches ALL attention modules (video, audio, and cross-modal).
90
+ # For audio-video training, this is the recommended approach.
91
+ target_modules:
92
+ # Attention layers (matches both video and audio branches)
93
+ - "to_k"
94
+ - "to_q"
95
+ - "to_v"
96
+ - "to_out.0"
97
+ # Uncomment below to also train feed-forward layers (can increase the LoRA's capacity):
98
+ # - "ff.net.0.proj"
99
+ # - "ff.net.2"
100
+ # - "audio_ff.net.0.proj"
101
+ # - "audio_ff.net.2"
102
+
103
+ # -----------------------------------------------------------------------------
104
+ # Training Strategy Configuration
105
+ # -----------------------------------------------------------------------------
106
+ # Defines the text-to-video training approach.
107
+ training_strategy:
108
+ # Strategy name: "text_to_video" for standard text-to-video training
109
+ name: "text_to_video"
110
+
111
+ # Probability of conditioning on the first frame during training
112
+ # Higher values train the model to perform better in image-to-video (I2V) mode,
113
+ # where a clean first frame is provided and the model generates the rest of the video
114
+ # Increase this value to train the model to perform better in image-to-video (I2V) mode
115
+ first_frame_conditioning_p: 0.5
116
+
117
+ # Enable joint audio-video training
118
+ # Set to true if your dataset includes audio and you want to train the audio branch
119
+ with_audio: true
120
+
121
+ # Directory name (within preprocessed_data_root) containing audio latents
122
+ # Only used when with_audio is true
123
+ audio_latents_dir: "audio_latents"
124
+
125
+ # -----------------------------------------------------------------------------
126
+ # Optimization Configuration
127
+ # -----------------------------------------------------------------------------
128
+ # Controls the training optimization parameters.
129
+ optimization:
130
+ # Learning rate for the optimizer
131
+ # Typical range for LoRA: 1e-5 to 1e-4
132
+ learning_rate: 1e-4
133
+
134
+ # Total number of training steps
135
+ steps: 2000
136
+
137
+ # Batch size per GPU
138
+ # Reduce if running out of memory
139
+ batch_size: 1
140
+
141
+ # Number of gradient accumulation steps
142
+ # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus
143
+ gradient_accumulation_steps: 1
144
+
145
+ # Maximum gradient norm for clipping (helps training stability)
146
+ max_grad_norm: 1.0
147
+
148
+ # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient)
149
+ # Using 8-bit AdamW to reduce optimizer state memory by ~75%
150
+ optimizer_type: "adamw8bit"
151
+
152
+ # Learning rate scheduler type
153
+ # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial"
154
+ scheduler_type: "linear"
155
+
156
+ # Additional scheduler parameters (depends on scheduler_type)
157
+ scheduler_params: { }
158
+
159
+ # Enable gradient checkpointing to reduce memory usage
160
+ # Recommended for training with limited GPU memory
161
+ enable_gradient_checkpointing: true
162
+
163
+ # -----------------------------------------------------------------------------
164
+ # Acceleration Configuration
165
+ # -----------------------------------------------------------------------------
166
+ # Hardware acceleration and memory optimization settings.
167
+ acceleration:
168
+ # Mixed precision training mode
169
+ # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended)
170
+ mixed_precision_mode: "bf16"
171
+
172
+ # Model quantization for reduced memory usage
173
+ # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto"
174
+ # Using INT8 quantization to reduce base model memory consumption by ~50%
175
+ quantization: "int8-quanto"
176
+
177
+ # Load text encoder in 8-bit precision to save memory
178
+ # Useful when GPU memory is limited
179
+ load_text_encoder_in_8bit: true
180
+
181
+ # -----------------------------------------------------------------------------
182
+ # Data Configuration
183
+ # -----------------------------------------------------------------------------
184
+ # Specifies the training data location and loading parameters.
185
+ data:
186
+ # Root directory containing preprocessed training data
187
+ # Should contain: latents/, conditions/, and optionally audio_latents/
188
+ preprocessed_data_root: "/path/to/preprocessed/data"
189
+
190
+ # Number of worker processes for data loading
191
+ # Used for parallel data loading to speed up data loading
192
+ num_dataloader_workers: 2
193
+
194
+ # -----------------------------------------------------------------------------
195
+ # Validation Configuration
196
+ # -----------------------------------------------------------------------------
197
+ # Controls validation video generation during training.
198
+ # NOTE: Validation sampling use simplified inference pipelines and prioritizes speed over
199
+ # maximum quality. For production-quality inference, use `packages/ltx-pipelines`.
200
+ validation:
201
+ # Text prompts for validation video generation
202
+ # Provide prompts representative of your training data
203
+ # LTX-2 prefers longer, detailed prompts that describe both visual content and audio
204
+ prompts:
205
+ - "A woman with long brown hair sits at a wooden desk in a cozy home office, typing on a laptop while occasionally glancing at notes beside her. Soft natural light streams through a large window, casting warm shadows across the room. She pauses to take a sip from a ceramic mug, then continues working with focused concentration. The audio captures the gentle clicking of keyboard keys, the soft rustle of papers, and ambient room tone with occasional distant bird chirps from outside."
206
+ - "A chef in a white uniform stands in a professional kitchen, carefully plating a gourmet dish with precise movements. Steam rises from freshly cooked vegetables as he arranges them with tweezers. The stainless steel surfaces gleam under bright overhead lights, and various pots simmer on the stove behind him. The audio features the sizzling of pans, the clinking of utensils against plates, and the ambient hum of kitchen ventilation."
207
+
208
+ # Negative prompt to avoid unwanted artifacts
209
+ negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
210
+
211
+ # Optional: First frame images for image-to-video validation
212
+ # If provided, must have one image per prompt
213
+ images: null
214
+
215
+ # Output video dimensions [width, height, frames]
216
+ # Width and height must be divisible by 32
217
+ # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...)
218
+ video_dims: [ 576, 576, 49 ]
219
+
220
+ # Frame rate for generated videos
221
+ frame_rate: 25.0
222
+
223
+ # Random seed for reproducible validation outputs
224
+ seed: 42
225
+
226
+ # Number of denoising steps for validation inference
227
+ # Higher values = better quality but slower generation
228
+ inference_steps: 30
229
+
230
+ # Generate validation videos every N training steps
231
+ # Set to null to disable validation during training
232
+ interval: 100
233
+
234
+ # Number of videos to generate per prompt
235
+ videos_per_prompt: 1
236
+
237
+ # Classifier-free guidance scale
238
+ # Higher values = stronger adherence to prompt but may introduce artifacts
239
+ guidance_scale: 4.0
240
+
241
+ # STG (Spatio-Temporal Guidance) parameters for improved video quality
242
+ # STG is combined with CFG for better temporal coherence
243
+ stg_scale: 1.0 # Recommended: 1.0 (0.0 disables STG)
244
+ stg_blocks: [ 29 ] # Recommended: single block 29
245
+ stg_mode: "stg_av" # "stg_av" perturbs both audio and video, "stg_v" video only
246
+
247
+ # Whether to generate audio in validation samples
248
+ # Independent of training_strategy.with_audio - you can generate audio
249
+ # in validation even when not training the audio branch
250
+ generate_audio: true
251
+
252
+ # Skip validation at the beginning of training (step 0)
253
+ skip_initial_validation: false
254
+
255
+ # -----------------------------------------------------------------------------
256
+ # Checkpoint Configuration
257
+ # -----------------------------------------------------------------------------
258
+ # Controls model checkpoint saving during training.
259
+ checkpoints:
260
+ # Save a checkpoint every N steps
261
+ # Set to null to disable intermediate checkpoints
262
+ interval: 250
263
+
264
+ # Number of most recent checkpoints to keep
265
+ # Set to -1 to keep all checkpoints
266
+ keep_last_n: -1
267
+
268
+ # Precision to use when saving checkpoint weights
269
+ # Options: "bfloat16" (default, smaller files) or "float32" (full precision)
270
+ precision: "bfloat16"
271
+
272
+ # -----------------------------------------------------------------------------
273
+ # Flow Matching Configuration
274
+ # -----------------------------------------------------------------------------
275
+ # Parameters for the flow matching training objective.
276
+ flow_matching:
277
+ # Timestep sampling mode
278
+ # "shifted_logit_normal" is recommended for LTX-2 models
279
+ timestep_sampling_mode: "shifted_logit_normal"
280
+
281
+ # Additional parameters for timestep sampling
282
+ timestep_sampling_params: { }
283
+
284
+ # -----------------------------------------------------------------------------
285
+ # Hugging Face Hub Configuration
286
+ # -----------------------------------------------------------------------------
287
+ # Settings for uploading trained models to the Hugging Face Hub.
288
+ hub:
289
+ # Whether to push the trained model to the Hub
290
+ push_to_hub: false
291
+
292
+ # Repository ID on Hugging Face Hub (e.g., "username/my-lora-model")
293
+ # Required if push_to_hub is true
294
+ hub_model_id: null
295
+
296
+ # -----------------------------------------------------------------------------
297
+ # Weights & Biases Configuration
298
+ # -----------------------------------------------------------------------------
299
+ # Settings for experiment tracking with W&B.
300
+ wandb:
301
+ # Enable W&B logging
302
+ enabled: false
303
+
304
+ # W&B project name
305
+ project: "ltx-2-trainer"
306
+
307
+ # W&B username or team (null uses default account)
308
+ entity: null
309
+
310
+ # Tags to help organize runs
311
+ tags: [ "ltx2", "lora" ]
312
+
313
+ # Log validation videos to W&B
314
+ log_validation_videos: true
315
+
316
+ # -----------------------------------------------------------------------------
317
+ # General Configuration
318
+ # -----------------------------------------------------------------------------
319
+ # Global settings for the training run.
320
+
321
+ # Random seed for reproducibility
322
+ seed: 42
323
+
324
+ # Directory to save outputs (checkpoints, validation videos, logs)
325
+ output_dir: "outputs/ltx2_av_lora"
packages/ltx-trainer/configs/ltx2_v2v_ic_lora.yaml ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # LTX-2 Video-to-Video (IC-LoRA) Training Configuration
3
+ # =============================================================================
4
+ #
5
+ # This configuration is for training In-Context LoRA (IC-LoRA) adapters that
6
+ # enable video-to-video transformations. IC-LoRA learns to apply visual
7
+ # transformations (e.g., depth-to-video, pose control, style transfer, etc.)
8
+ # by conditioning on reference videos.
9
+ #
10
+ # Key differences from text-to-video LoRA:
11
+ # - Uses reference videos as conditioning input alongside text prompts
12
+ # - Requires preprocessed reference latents in addition to target latents
13
+ # - Validation requires reference videos to demonstrate the transformation
14
+ #
15
+ # Dataset structure for IC-LoRA training:
16
+ # preprocessed_data_root/
17
+ # ├── latents/ # Target video latents (what the model learns to generate)
18
+ # ├── conditions/ # Text embeddings for each video
19
+ # └── reference_latents/ # Reference video latents (conditioning input)
20
+ #
21
+ # =============================================================================
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Model Configuration
25
+ # -----------------------------------------------------------------------------
26
+ # Specifies the base model to fine-tune and the training mode.
27
+ model:
28
+ # Path to the LTX-2 model checkpoint (.safetensors file)
29
+ # This should be a local path to your downloaded model
30
+ model_path: "path/to/ltx-2-model.safetensors"
31
+
32
+ # Path to the text encoder model directory
33
+ # For LTX-2, this is typically the Gemma-based text encoder
34
+ text_encoder_path: "path/to/gemma-text-encoder"
35
+
36
+ # Training mode: "lora" for efficient adapter training, "full" for full fine-tuning
37
+ # Note: video_to_video strategy requires "lora" mode
38
+ training_mode: "lora"
39
+
40
+ # Optional: Path to resume training from a checkpoint
41
+ # Can be a checkpoint file (.safetensors) or directory (uses latest checkpoint)
42
+ load_checkpoint: null
43
+
44
+ # -----------------------------------------------------------------------------
45
+ # LoRA Configuration
46
+ # -----------------------------------------------------------------------------
47
+ # Controls the Low-Rank Adaptation parameters for efficient fine-tuning.
48
+ lora:
49
+ # Rank of the LoRA matrices (higher = more capacity but more parameters)
50
+ # Typical values: 8, 16, 32, 64. Start with 16-32 for IC-LoRA.
51
+ rank: 32
52
+
53
+ # Alpha scaling factor (usually set equal to rank)
54
+ # The effective scaling is alpha/rank, so alpha=rank means scaling of 1.0
55
+ alpha: 32
56
+
57
+ # Dropout probability for LoRA layers (0.0 = no dropout)
58
+ # Can help with regularization if overfitting occurs
59
+ dropout: 0.0
60
+
61
+ # Which transformer modules to apply LoRA to
62
+ # The LTX-2 transformer has separate attention and FFN blocks for video and audio:
63
+ #
64
+ # VIDEO MODULES:
65
+ # - attn1.to_k, attn1.to_q, attn1.to_v, attn1.to_out.0 (video self-attention)
66
+ # - attn2.to_k, attn2.to_q, attn2.to_v, attn2.to_out.0 (video cross-attention to text)
67
+ # - ff.net.0.proj, ff.net.2 (video feed-forward)
68
+ #
69
+ # AUDIO MODULES (not used for video-only IC-LoRA):
70
+ # - audio_attn1.to_k, audio_attn1.to_q, audio_attn1.to_v, audio_attn1.to_out.0 (audio self-attention)
71
+ # - audio_attn2.to_k, audio_attn2.to_q, audio_attn2.to_v, audio_attn2.to_out.0 (audio cross-attention to text)
72
+ # - audio_ff.net.0.proj, audio_ff.net.2 (audio feed-forward)
73
+ #
74
+ # AUDIO-VIDEO CROSS-ATTENTION MODULES (for cross-modal interaction, not used for video-only IC-LoRA):
75
+ # - audio_to_video_attn.to_k, audio_to_video_attn.to_q, audio_to_video_attn.to_v, audio_to_video_attn.to_out.0
76
+ # (Q from video, K/V from audio - allows video to attend to audio features)
77
+ # - video_to_audio_attn.to_k, video_to_audio_attn.to_q, video_to_audio_attn.to_v, video_to_audio_attn.to_out.0
78
+ # (Q from audio, K/V from video - allows audio to attend to video features)
79
+ #
80
+ # For IC-LoRA (video-only), we explicitly target video modules.
81
+ # Including FFN layers often improves transformation quality.
82
+ target_modules:
83
+ # Video self-attention
84
+ - "attn1.to_k"
85
+ - "attn1.to_q"
86
+ - "attn1.to_v"
87
+ - "attn1.to_out.0"
88
+ # Video cross-attention
89
+ - "attn2.to_k"
90
+ - "attn2.to_q"
91
+ - "attn2.to_v"
92
+ - "attn2.to_out.0"
93
+ # Video feed-forward (often improves transformation quality)
94
+ - "ff.net.0.proj"
95
+ - "ff.net.2"
96
+
97
+ # -----------------------------------------------------------------------------
98
+ # Training Strategy Configuration
99
+ # -----------------------------------------------------------------------------
100
+ # Defines the video-to-video (IC-LoRA) training approach.
101
+ training_strategy:
102
+ # Strategy name: "video_to_video" for IC-LoRA training
103
+ name: "video_to_video"
104
+
105
+ # Probability of conditioning on the first frame during training
106
+ # Higher values train the model to perform better in image-to-video (I2V) mode,
107
+ # where a clean first frame is provided and the model generates the rest of the video
108
+ # Increase this value to train the model to perform better in image-to-video (I2V) mode
109
+ first_frame_conditioning_p: 0.2
110
+
111
+ # Directory name (within preprocessed_data_root) containing reference video latents
112
+ # These are the conditioning inputs that guide the transformation
113
+ reference_latents_dir: "reference_latents"
114
+
115
+ # -----------------------------------------------------------------------------
116
+ # Optimization Configuration
117
+ # -----------------------------------------------------------------------------
118
+ # Controls the training optimization parameters.
119
+ optimization:
120
+ # Learning rate for the optimizer
121
+ # Typical range for LoRA: 1e-5 to 1e-4
122
+ learning_rate: 2e-4
123
+
124
+ # Total number of training steps
125
+ steps: 3000
126
+
127
+ # Batch size per GPU
128
+ # Reduce if running out of memory
129
+ batch_size: 1
130
+
131
+ # Number of gradient accumulation steps
132
+ # Effective batch size = batch_size * gradient_accumulation_steps * num_gpus
133
+ gradient_accumulation_steps: 1
134
+
135
+ # Maximum gradient norm for clipping (helps training stability)
136
+ max_grad_norm: 1.0
137
+
138
+ # Optimizer type: "adamw" (standard) or "adamw8bit" (memory-efficient)
139
+ optimizer_type: "adamw"
140
+
141
+ # Learning rate scheduler type
142
+ # Options: "constant", "linear", "cosine", "cosine_with_restarts", "polynomial"
143
+ scheduler_type: "linear"
144
+
145
+ # Additional scheduler parameters (depends on scheduler_type)
146
+ scheduler_params: { }
147
+
148
+ # Enable gradient checkpointing to reduce memory usage
149
+ # Recommended for training with limited GPU memory
150
+ enable_gradient_checkpointing: true
151
+
152
+ # -----------------------------------------------------------------------------
153
+ # Acceleration Configuration
154
+ # -----------------------------------------------------------------------------
155
+ # Hardware acceleration and memory optimization settings.
156
+ acceleration:
157
+ # Mixed precision training mode
158
+ # Options: "no" (fp32), "fp16" (half precision), "bf16" (bfloat16, recommended)
159
+ mixed_precision_mode: "bf16"
160
+
161
+ # Model quantization for reduced memory usage
162
+ # Options: null (none), "int8-quanto", "int4-quanto", "int2-quanto", "fp8-quanto", "fp8uz-quanto"
163
+ quantization: null
164
+
165
+ # Load text encoder in 8-bit precision to save memory
166
+ # Useful when GPU memory is limited
167
+ load_text_encoder_in_8bit: false
168
+
169
+ # -----------------------------------------------------------------------------
170
+ # Data Configuration
171
+ # -----------------------------------------------------------------------------
172
+ # Specifies the training data location and loading parameters.
173
+ data:
174
+ # Root directory containing preprocessed training data
175
+ # Should contain: latents/, conditions/, and reference_latents/ subdirectories
176
+ preprocessed_data_root: "/path/to/preprocessed/data"
177
+
178
+ # Number of worker processes for data loading
179
+ # More workers load data in parallel, which can speed up training
180
+ num_dataloader_workers: 2
181
+
182
+ # -----------------------------------------------------------------------------
183
+ # Validation Configuration
184
+ # -----------------------------------------------------------------------------
185
+ # Controls validation video generation during training.
186
+ # NOTE: Validation sampling uses a simplified inference pipeline and prioritizes speed over
187
+ # maximum quality. For production-quality inference, use `packages/ltx-pipelines`.
188
+ validation:
189
+ # Text prompts for validation video generation
190
+ # Provide prompts representative of your training data
191
+ # LTX-2 prefers longer, detailed prompts that describe both visual content and audio
192
+ prompts:
193
+ - "A man in a casual blue jacket walks along a winding path through a lush green park on a bright sunny afternoon. Tall oak trees line the pathway, their leaves rustling gently in the breeze. Dappled sunlight creates shifting patterns on the ground as he strolls at a relaxed pace, occasionally looking up at the scenery around him. The audio captures footsteps on gravel, birds singing in the trees, distant children playing, and the soft whisper of wind through the foliage."
194
+ - "A fluffy orange tabby cat sits perfectly still on a wooden windowsill, its green eyes intently tracking small birds hopping on a branch just outside the glass. The cat's ears twitch and rotate, following every movement. Warm afternoon light illuminates its fur, creating a soft golden glow. Behind the cat, a cozy living room with a bookshelf and houseplants is visible. The audio features gentle purring, occasional soft meows, muffled bird chirps through the window, and quiet ambient room sounds."
195
+
196
+ # Reference videos for validation (REQUIRED for video_to_video strategy)
197
+ # Must provide one reference video per prompt
198
+ # These are the conditioning inputs for generating validation outputs
199
+ reference_videos:
200
+ - "/path/to/reference_video_1.mp4"
201
+ - "/path/to/reference_video_2.mp4"
202
+
203
+ # Downscale factor for reference videos (for efficient IC-LoRA training)
204
+ # When > 1, reference videos are processed at 1/n resolution
205
+ # Must match the --reference-downscale-factor used during dataset preprocessing
206
+ # Examples: 1 = same resolution, 2 = half resolution (384x384 ref for 768x768 target)
207
+ reference_downscale_factor: 1
208
+
209
+ # Negative prompt to avoid unwanted artifacts
210
+ negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
211
+
212
+ # Optional: First frame images for additional conditioning
213
+ # If provided, must have one image per prompt
214
+ images: null
215
+
216
+ # Output video dimensions [width, height, frames]
217
+ # Width and height must be divisible by 32
218
+ # Frames must satisfy: frames % 8 == 1 (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, ...)
219
+ video_dims: [ 512, 512, 81 ]
220
+
221
+ # Frame rate for generated videos
222
+ frame_rate: 25.0
223
+
224
+ # Random seed for reproducible validation outputs
225
+ seed: 42
226
+
227
+ # Number of denoising steps for validation inference
228
+ # Higher values = better quality but slower generation
229
+ inference_steps: 30
230
+
231
+ # Generate validation videos every N training steps
232
+ # Set to null to disable validation during training
233
+ interval: 100
234
+
235
+ # Number of videos to generate per prompt
236
+ videos_per_prompt: 1
237
+
238
+ # Classifier-free guidance scale
239
+ # Higher values = stronger adherence to prompt but may introduce artifacts
240
+ guidance_scale: 4.0
241
+
242
+ # STG (Spatio-Temporal Guidance) parameters for improved video quality
243
+ # STG is combined with CFG for better temporal coherence
244
+ stg_scale: 1.0 # Recommended: 1.0 (0.0 disables STG)
245
+ stg_blocks: [29] # Recommended: single block 29
246
+ stg_mode: "stg_v" # "stg_v" for video-only (no audio training)
247
+
248
+ # Whether to generate audio in validation samples
249
+ # Can be enabled even when not training the audio branch
250
+ generate_audio: false
251
+
252
+ # Skip validation at the beginning of training (step 0)
253
+ skip_initial_validation: false
254
+
255
+ # Concatenate reference video side-by-side with generated output
256
+ # Useful for visually comparing the transformation quality
257
+ include_reference_in_output: true
258
+
259
+ # -----------------------------------------------------------------------------
260
+ # Checkpoint Configuration
261
+ # -----------------------------------------------------------------------------
262
+ # Controls model checkpoint saving during training.
263
+ checkpoints:
264
+ # Save a checkpoint every N steps
265
+ # Set to null to disable intermediate checkpoints
266
+ interval: 250
267
+
268
+ # Number of most recent checkpoints to keep
269
+ # Set to -1 to keep all checkpoints
270
+ keep_last_n: 3
271
+
272
+ # Precision to use when saving checkpoint weights
273
+ # Options: "bfloat16" (default, smaller files) or "float32" (full precision)
274
+ precision: "bfloat16"
275
+
276
+ # -----------------------------------------------------------------------------
277
+ # Flow Matching Configuration
278
+ # -----------------------------------------------------------------------------
279
+ # Parameters for the flow matching training objective.
280
+ flow_matching:
281
+ # Timestep sampling mode
282
+ # "shifted_logit_normal" is recommended for LTX-2 models
283
+ timestep_sampling_mode: "shifted_logit_normal"
284
+
285
+ # Additional parameters for timestep sampling
286
+ timestep_sampling_params: { }
287
+
288
+ # -----------------------------------------------------------------------------
289
+ # Hugging Face Hub Configuration
290
+ # -----------------------------------------------------------------------------
291
+ # Settings for uploading trained models to the Hugging Face Hub.
292
+ hub:
293
+ # Whether to push the trained model to the Hub
294
+ push_to_hub: false
295
+
296
+ # Repository ID on Hugging Face Hub (e.g., "username/my-ic-lora-model")
297
+ # Required if push_to_hub is true
298
+ hub_model_id: null
299
+
300
+ # -----------------------------------------------------------------------------
301
+ # Weights & Biases Configuration
302
+ # -----------------------------------------------------------------------------
303
+ # Settings for experiment tracking with W&B.
304
+ wandb:
305
+ # Enable W&B logging
306
+ enabled: false
307
+
308
+ # W&B project name
309
+ project: "ltx-2-trainer"
310
+
311
+ # W&B username or team (null uses default account)
312
+ entity: null
313
+
314
+ # Tags to help organize runs
315
+ tags: [ "ltx2", "ic-lora", "video-to-video" ]
316
+
317
+ # Log validation videos to W&B
318
+ log_validation_videos: true
319
+
320
+ # -----------------------------------------------------------------------------
321
+ # General Configuration
322
+ # -----------------------------------------------------------------------------
323
+ # Global settings for the training run.
324
+
325
+ # Random seed for reproducibility
326
+ seed: 42
327
+
328
+ # Directory to save outputs (checkpoints, validation videos, logs)
329
+ output_dir: "outputs/ltx2_v2v_ic_lora"
packages/ltx-trainer/docs/configuration-reference.md ADDED
@@ -0,0 +1,372 @@
1
+ # Configuration Reference
2
+
3
+ The trainer uses structured Pydantic models for configuration, making it easy to customize training parameters.
4
+ This guide covers all available configuration options and their usage.
5
+
6
+ ## 📋 Overview
7
+
8
+ The main configuration class is [`LtxTrainerConfig`](../src/ltx_trainer/config.py), which includes the following
9
+ sub-configurations:
10
+
11
+ - **ModelConfig**: Base model and training mode settings
12
+ - **LoraConfig**: LoRA training parameters
13
+ - **TrainingStrategyConfig**: Training strategy settings (text-to-video or video-to-video)
14
+ - **OptimizationConfig**: Learning rate, batch sizes, and scheduler settings
15
+ - **AccelerationConfig**: Mixed precision and quantization settings
16
+ - **DataConfig**: Data loading parameters
17
+ - **ValidationConfig**: Validation and inference settings
18
+ - **CheckpointsConfig**: Checkpoint saving frequency and retention settings
19
+ - **HubConfig**: Hugging Face Hub integration settings
20
+ - **WandbConfig**: Weights & Biases logging settings
21
+ - **FlowMatchingConfig**: Timestep sampling parameters
22
+
23
+ ## 📄 Example Configuration Files
24
+
25
+ Check out our example configurations in the `configs` directory:
26
+
27
+ - 📄 [Audio-Video LoRA Training](../configs/ltx2_av_lora.yaml) - Joint audio-video generation training
28
+ - 📄 [Audio-Video LoRA Training (Low VRAM)](../configs/ltx2_av_lora_low_vram.yaml) - Memory-optimized config for 32GB
29
+ GPUs (uses 8-bit optimizer, INT8 quantization, and reduced LoRA rank)
30
+ - 📄 [IC-LoRA Training](../configs/ltx2_v2v_ic_lora.yaml) - Video-to-video transformation training
31
+
32
+ ## ⚙️ Configuration Sections
33
+
34
+ ### ModelConfig
35
+
36
+ Controls the base model and training mode settings.
37
+
38
+ ```yaml
39
+ model:
40
+ model_path: "/path/to/ltx-2-model.safetensors" # Local path to model checkpoint
41
+ text_encoder_path: "/path/to/gemma-model" # Path to Gemma text encoder directory
42
+ training_mode: "lora" # "lora" or "full"
43
+ load_checkpoint: null # Path to checkpoint to resume from
44
+ ```
45
+
46
+ **Key parameters:**
47
+
48
+ | Parameter | Description |
49
+ |---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------|
50
+ | `model_path` | **Required.** Local path to the LTX-2 model checkpoint (`.safetensors` file). URLs are not supported. |
51
+ | `text_encoder_path` | **Required.** Path to the Gemma text encoder model directory. Download from [HuggingFace](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/). |
52
+ | `training_mode` | Training approach - `"lora"` for LoRA training or `"full"` for full-rank fine-tuning. |
53
+ | `load_checkpoint` | Optional path to resume training from a checkpoint file or directory. |
54
+
55
+ > [!NOTE]
56
+ > LTX-2 requires both a model checkpoint and a Gemma text encoder. Both must be local paths.
57
+
58
+ ### LoraConfig
59
+
60
+ LoRA-specific fine-tuning parameters (only used when `training_mode: "lora"`).
61
+
62
+ ```yaml
63
+ lora:
64
+ rank: 32 # LoRA rank (higher = more parameters)
65
+ alpha: 32 # LoRA alpha scaling factor
66
+ dropout: 0.0 # Dropout probability (0.0-1.0)
67
+ target_modules: # Modules to apply LoRA to
68
+ - "to_k"
69
+ - "to_q"
70
+ - "to_v"
71
+ - "to_out.0"
72
+ ```
73
+
74
+ **Key parameters:**
75
+
76
+ | Parameter | Description |
77
+ |------------------|---------------------------------------------------------------------------------|
78
+ | `rank` | LoRA rank - higher values mean more trainable parameters (typical range: 8-128) |
79
+ | `alpha` | Alpha scaling factor - typically set equal to rank |
80
+ | `dropout` | Dropout probability for regularization |
81
+ | `target_modules` | List of transformer modules to apply LoRA adapters to (see below) |
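+
+ In standard LoRA implementations the adapter update is scaled by `alpha / rank`, so setting `alpha` equal to `rank`
+ keeps that scale at 1.0, which is why the example configs use matching values.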
82
+
83
+ #### Understanding Target Modules
84
+
85
+ The LTX-2 transformer has separate attention and feed-forward blocks for video and audio, as well as cross-attention
86
+ modules that enable the two modalities to exchange information. Choosing the right `target_modules` is critical for
87
+ achieving good results, especially when training with audio.
88
+
89
+ **Video-only modules:**
90
+
91
+ | Module Pattern | Description |
92
+ |------------------------------------------------------------|---------------------------------|
93
+ | `attn1.to_k`, `attn1.to_q`, `attn1.to_v`, `attn1.to_out.0` | Video self-attention |
94
+ | `attn2.to_k`, `attn2.to_q`, `attn2.to_v`, `attn2.to_out.0` | Video cross-attention (to text) |
95
+ | `ff.net.0.proj`, `ff.net.2` | Video feed-forward network |
96
+
97
+ **Audio-only modules:**
98
+
99
+ | Module Pattern | Description |
100
+ |------------------------------------------------------------------------------------|---------------------------------|
101
+ | `audio_attn1.to_k`, `audio_attn1.to_q`, `audio_attn1.to_v`, `audio_attn1.to_out.0` | Audio self-attention |
102
+ | `audio_attn2.to_k`, `audio_attn2.to_q`, `audio_attn2.to_v`, `audio_attn2.to_out.0` | Audio cross-attention (to text) |
103
+ | `audio_ff.net.0.proj`, `audio_ff.net.2` | Audio feed-forward network |
104
+
105
+ **Audio-video cross-attention modules:**
106
+
107
+ These modules enable bidirectional information flow between the audio and video modalities:
108
+
109
+ | Module Pattern | Description |
110
+ |--------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------|
111
+ | `audio_to_video_attn.to_k`, `audio_to_video_attn.to_q`, `audio_to_video_attn.to_v`, `audio_to_video_attn.to_out.0` | Video attends to audio (Q from video, K/V from audio) |
112
+ | `video_to_audio_attn.to_k`, `video_to_audio_attn.to_q`, `video_to_audio_attn.to_v`, `video_to_audio_attn.to_out.0` | Audio attends to video (Q from audio, K/V from video) |
113
+
114
+ **Recommended configurations:**
115
+
116
+ For **video-only training**, target the video attention layers:
117
+
118
+ ```yaml
119
+ target_modules:
120
+ - "attn1.to_k"
121
+ - "attn1.to_q"
122
+ - "attn1.to_v"
123
+ - "attn1.to_out.0"
124
+ - "attn2.to_k"
125
+ - "attn2.to_q"
126
+ - "attn2.to_v"
127
+ - "attn2.to_out.0"
128
+ ```
129
+
130
+ For **audio-video training**, use patterns that match both branches:
131
+
132
+ ```yaml
133
+ target_modules:
134
+ - "to_k"
135
+ - "to_q"
136
+ - "to_v"
137
+ - "to_out.0"
138
+ ```
139
+
140
+ > [!NOTE]
141
+ > Using shorter patterns like `"to_k"` will match all attention modules including `attn1.to_k`, `audio_attn1.to_k`,
142
+ > `audio_to_video_attn.to_k`, and `video_to_audio_attn.to_k`, effectively training video, audio, and cross-modal
143
+ > attention branches together.
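+
+ To make the matching behaviour concrete, the sketch below assumes suffix-based matching (as in common LoRA
+ implementations); the module names are illustrative and the trainer's exact matching rule may differ:
+
+ ```python
+ module_names = [
+     "blocks.0.attn1.to_k",
+     "blocks.0.audio_attn1.to_k",
+     "blocks.0.audio_to_video_attn.to_k",
+     "blocks.0.ff.net.0.proj",
+ ]
+
+ target_modules = ["to_k"]  # short pattern
+
+ # Suffix matching: every *.to_k module matches, the feed-forward projection does not
+ matched = [name for name in module_names if any(name.endswith(t) for t in target_modules)]
+ print(matched)
+ ```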
144
+
145
+ > [!TIP]
146
+ > You can also target the feed-forward (FFN) modules (`ff.net.0.proj`, `ff.net.2` for video,
147
+ > `audio_ff.net.0.proj`, `audio_ff.net.2` for audio) to increase the LoRA's capacity and potentially
148
+ > help it capture the target distribution better.
149
+
150
+ ### TrainingStrategyConfig
151
+
152
+ Configures the training strategy. The trainer includes two built-in strategies described below.
153
+ For custom use cases, see [Implementing Custom Training Strategies](custom-training-strategies.md).
154
+
155
+ #### Text-to-Video Strategy
156
+
157
+ ```yaml
158
+ training_strategy:
159
+ name: "text_to_video"
160
+ first_frame_conditioning_p: 0.1 # Probability of first-frame conditioning
161
+ with_audio: false # Enable joint audio-video training
162
+ audio_latents_dir: "audio_latents" # Directory for audio latents (when with_audio: true)
163
+ ```
164
+
165
+ #### Video-to-Video Strategy (IC-LoRA)
166
+
167
+ ```yaml
168
+ training_strategy:
169
+ name: "video_to_video"
170
+ first_frame_conditioning_p: 0.1
171
+ reference_latents_dir: "reference_latents" # Directory for reference video latents
172
+ ```
173
+
174
+ **Key parameters:**
175
+
176
+ | Parameter | Description |
177
+ |------------------------------|------------------------------------------------------------------|
178
+ | `name` | Strategy type: `"text_to_video"` or `"video_to_video"` |
179
+ | `first_frame_conditioning_p` | Probability of using first frame as conditioning (0.0-1.0) |
180
+ | `with_audio` | (text_to_video only) Enable joint audio-video training |
181
+ | `audio_latents_dir` | (text_to_video only) Directory name for audio latents |
182
+ | `reference_latents_dir` | (video_to_video only) Directory name for reference video latents |
183
+
184
+ ### OptimizationConfig
185
+
186
+ Training optimization parameters including learning rates, batch sizes, and schedulers.
187
+
188
+ ```yaml
189
+ optimization:
190
+ learning_rate: 1e-4 # Learning rate
191
+ steps: 2000 # Total training steps
192
+ batch_size: 1 # Batch size per GPU
193
+ gradient_accumulation_steps: 1 # Steps to accumulate gradients
194
+ max_grad_norm: 1.0 # Gradient clipping threshold
195
+ optimizer_type: "adamw" # "adamw" or "adamw8bit"
196
+ scheduler_type: "linear" # Scheduler type
197
+ scheduler_params: { } # Additional scheduler parameters
198
+ enable_gradient_checkpointing: true # Memory optimization
199
+ ```
200
+
201
+ **Key parameters:**
202
+
203
+ | Parameter | Description |
204
+ |---------------------------------|----------------------------------------------------------------------------------------------|
205
+ | `learning_rate` | Learning rate for optimization (typical range: 1e-5 to 1e-3) |
206
+ | `steps` | Total number of training steps |
207
+ | `batch_size` | Batch size per GPU (reduce if running out of memory) |
208
+ | `gradient_accumulation_steps` | Accumulate gradients over multiple steps |
209
+ | `scheduler_type` | LR scheduler: `"constant"`, `"linear"`, `"cosine"`, `"cosine_with_restarts"`, `"polynomial"` |
210
+ | `enable_gradient_checkpointing` | Trade training speed for GPU memory savings (recommended for large models) |
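+
+ The effective batch size is `batch_size * gradient_accumulation_steps * num_gpus`; for example, `batch_size: 1` with
+ `gradient_accumulation_steps: 4` on 2 GPUs gives an effective batch size of 8.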
211
+
212
+ ### AccelerationConfig
213
+
214
+ Hardware acceleration and compute optimization settings.
215
+
216
+ ```yaml
217
+ acceleration:
218
+ mixed_precision_mode: "bf16" # "no", "fp16", or "bf16"
219
+ quantization: null # Quantization options
220
+ load_text_encoder_in_8bit: false # Load text encoder in 8-bit
221
+ ```
222
+
223
+ **Key parameters:**
224
+
225
+ | Parameter | Description |
226
+ |-----------------------------|------------------------------------------------------------------------------------|
227
+ | `mixed_precision_mode` | Precision mode - `"bf16"` recommended for modern GPUs |
228
+ | `quantization` | Model quantization: `null`, `"int8-quanto"`, `"int4-quanto"`, `"fp8-quanto"`, etc. |
229
+ | `load_text_encoder_in_8bit` | Load the Gemma text encoder in 8-bit to save GPU memory |
230
+
231
+ ### DataConfig
232
+
233
+ Data loading and processing configuration.
234
+
235
+ ```yaml
236
+ data:
237
+ preprocessed_data_root: "/path/to/preprocessed/data" # Path to precomputed dataset
238
+ num_dataloader_workers: 2 # Background data loading workers
239
+ ```
240
+
241
+ **Key parameters:**
242
+
243
+ | Parameter | Description |
244
+ |--------------------------|--------------------------------------------------------------------------------------------|
245
+ | `preprocessed_data_root` | Path to your preprocessed dataset (contains `latents/`, `conditions/`, etc.) |
246
+ | `num_dataloader_workers` | Number of parallel data loading processes (0 = synchronous loading, useful when debugging) |
247
+
248
+ ### ValidationConfig
249
+
250
+ Validation and inference settings for monitoring training progress.
251
+
252
+ ```yaml
253
+ validation:
254
+ prompts: # Validation prompts
255
+ - "A cat playing with a ball"
256
+ - "A dog running in a field"
257
+ negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
258
+ images: null # Optional image paths for image-to-video
259
+ reference_videos: null # Reference video paths (IC-LoRA only)
260
+ video_dims: [ 576, 576, 89 ] # Video dimensions [width, height, frames]
261
+ frame_rate: 25.0 # Frame rate for generated videos
262
+ seed: 42 # Random seed for reproducibility
263
+ inference_steps: 30 # Number of inference steps
264
+ interval: 100 # Steps between validation runs
265
+ videos_per_prompt: 1 # Videos generated per prompt
266
+ guidance_scale: 4.0 # CFG guidance strength
267
+ stg_scale: 1.0 # STG guidance strength (0.0 to disable)
268
+ stg_blocks: [ 29 ] # Transformer blocks to perturb for STG
269
+ stg_mode: "stg_av" # "stg_av" or "stg_v" (video only)
270
+ generate_audio: true # Whether to generate audio
271
+ skip_initial_validation: false # Skip validation at step 0
272
+ include_reference_in_output: false # Include reference video side-by-side (IC-LoRA)
273
+ ```
274
+
275
+ **Key parameters:**
276
+
277
+ | Parameter | Description |
278
+ |-------------------------------|--------------------------------------------------------------------------------------------------------------------------|
279
+ | `prompts` | List of text prompts for validation video generation |
280
+ | `images` | List of image paths for image-to-video validation (must match number of prompts) |
281
+ | `reference_videos` | List of reference video paths for IC-LoRA validation (must match number of prompts) |
282
+ | `video_dims` | Output dimensions `[width, height, frames]`. Width/height must be divisible by 32, frames must satisfy `frames % 8 == 1` |
283
+ | `interval` | Steps between validation runs (set to `null` to disable) |
284
+ | `guidance_scale` | CFG (Classifier-Free Guidance) scale. Recommended: 4.0 |
285
+ | `stg_scale` | STG (Spatio-Temporal Guidance) scale. 0.0 disables STG. Recommended: 1.0 |
286
+ | `stg_blocks` | Transformer blocks to perturb for STG. Recommended: `[29]` (single block) |
287
+ | `stg_mode` | STG mode: `"stg_av"` perturbs both audio and video, `"stg_v"` perturbs video only |
288
+ | `generate_audio` | Whether to generate audio in validation samples |
289
+ | `include_reference_in_output` | For IC-LoRA: concatenate reference video side-by-side with output |
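+
+ As a quick sanity check, the snippet below validates a candidate `video_dims` value against these constraints
+ (a minimal sketch; the trainer performs its own validation):
+
+ ```python
+ width, height, frames = 512, 512, 81  # candidate video_dims
+
+ assert width % 32 == 0 and height % 32 == 0, "width/height must be divisible by 32"
+ assert frames % 8 == 1, "frames must satisfy frames % 8 == 1"
+ ```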
290
+
291
+ ### CheckpointsConfig
292
+
293
+ Model checkpointing configuration.
294
+
295
+ ```yaml
296
+ checkpoints:
297
+ interval: 250 # Steps between checkpoint saves (null = disabled)
298
+ keep_last_n: 3 # Number of recent checkpoints to retain
299
+ precision: bfloat16 # Precision for saved weights (bfloat16 or float32)
300
+ ```
301
+
302
+ **Key parameters:**
303
+
304
+ | Parameter | Description |
305
+ |---------------|-------------------------------------------------------------------------------|
306
+ | `interval` | Steps between intermediate checkpoint saves (set to `null` to disable) |
307
+ | `keep_last_n` | Number of most recent checkpoints to keep (-1 = keep all) |
308
+ | `precision` | Precision for saved checkpoint weights: `"bfloat16"` (default) or `"float32"` |
309
+
310
+ ### HubConfig
311
+
312
+ Hugging Face Hub integration for automatic model uploads.
313
+
314
+ ```yaml
315
+ hub:
316
+ push_to_hub: false # Enable Hub uploading
317
+ hub_model_id: "username/model-name" # Hub repository ID
318
+ ```
319
+
320
+ **Key parameters:**
321
+
322
+ | Parameter | Description |
323
+ |----------------|------------------------------------------------------------------|
324
+ | `push_to_hub` | Whether to automatically push trained models to Hugging Face Hub |
325
+ | `hub_model_id` | Repository ID in format `"username/repository-name"` |
326
+
327
+ ### WandbConfig
328
+
329
+ Weights & Biases logging configuration.
330
+
331
+ ```yaml
332
+ wandb:
333
+ enabled: false # Enable W&B logging
334
+ project: "ltx-2-trainer" # W&B project name
335
+ entity: null # W&B username or team
336
+ tags: [ ] # Tags for the run
337
+ log_validation_videos: true # Log validation videos to W&B
338
+ ```
339
+
340
+ **Key parameters:**
341
+
342
+ | Parameter | Description |
343
+ |-------------------------|--------------------------------------------------|
344
+ | `enabled` | Whether to enable W&B logging |
345
+ | `project` | W&B project name |
346
+ | `entity` | W&B username or team (null uses default account) |
347
+ | `log_validation_videos` | Whether to log validation videos to W&B |
348
+
349
+ ### FlowMatchingConfig
350
+
351
+ Flow matching training configuration for timestep sampling.
352
+
353
+ ```yaml
354
+ flow_matching:
355
+ timestep_sampling_mode: "shifted_logit_normal" # Timestep sampling strategy
356
+ timestep_sampling_params: { } # Additional sampling parameters
357
+ ```
358
+
359
+ **Key parameters:**
360
+
361
+ | Parameter | Description |
362
+ |----------------------------|------------------------------------------------------------|
363
+ | `timestep_sampling_mode` | Sampling strategy: `"uniform"` or `"shifted_logit_normal"` |
364
+ | `timestep_sampling_params` | Additional parameters for the sampling strategy |
365
+
366
+ ## 🚀 Next Steps
367
+
368
+ Once you've configured your training parameters:
369
+
370
+ - Set up your dataset using [Dataset Preparation](dataset-preparation.md)
371
+ - Choose your training approach in [Training Modes](training-modes.md)
372
+ - Start training with the [Training Guide](training-guide.md)
packages/ltx-trainer/docs/custom-training-strategies.md ADDED
@@ -0,0 +1,510 @@
1
+ # Implementing Custom Training Strategies
2
+
3
+ This guide explains how to implement your own training strategy for specialized use cases like audio-only training,
4
+ video inpainting, or other custom training recipes.
5
+
6
+ ## 📋 Overview
7
+
8
+ The trainer uses the **Strategy Pattern** to separate training logic from the core training loop. Each strategy defines:
9
+
10
+ 1. **What data is needed** - Which preprocessed data directories to load
11
+ 2. **How to prepare inputs** - Transform batch data into model inputs
12
+ 3. **How to compute loss** - Calculate the training objective
13
+
14
+ This architecture lets you implement new training modes without modifying the core trainer code.
15
+
16
+ ### When You Need a Custom Strategy
17
+
18
+ Consider implementing a custom strategy when you need:
19
+
20
+ - **Different input modalities** (e.g., audio-only, audio-to-video conditioning)
21
+ - **Additional conditioning signals** (e.g., masks for inpainting, depth maps)
22
+ - **Custom loss computation** (e.g., weighted losses, auxiliary losses)
23
+ - **Different noise application patterns** (e.g., partial masking)
24
+
25
+ ## 🏗️ Architecture Overview
26
+
27
+ ### How Strategies Fit Into the Trainer
28
+
29
+ The trainer delegates all training-mode-specific logic to the strategy:
30
+
31
+ 1. **Initialization** — The trainer calls `get_data_sources()` to determine which preprocessed data directories to load
32
+ 2. **Each training step:**
33
+ - Calls `prepare_training_inputs()` to transform the raw batch into model-ready inputs
34
+ - Runs the transformer forward pass
35
+ - Calls `compute_loss()` to compute the training objective
36
+
37
+ The trainer handles everything else: optimization, checkpointing, validation, and distributed training.
38
+
39
+ ### Key Components
40
+
41
+ | Component | Purpose |
42
+ |-----------------------------------------------------------------------------------------|--------------------------------------------------------------|
43
+ | [`TrainingStrategyConfigBase`](../src/ltx_trainer/training_strategies/base_strategy.py) | Base class for strategy configuration (Pydantic model) |
44
+ | [`TrainingStrategy`](../src/ltx_trainer/training_strategies/base_strategy.py) | Abstract base class defining the strategy interface |
45
+ | [`ModelInputs`](../src/ltx_trainer/training_strategies/base_strategy.py) | Dataclass containing prepared inputs for the transformer |
46
+ | [`Modality`](../../ltx-core/src/ltx_core/model/transformer/modality.py) | ltx-core dataclass representing video or audio modality data |
47
+
48
+ ## 📝 Step-by-Step Implementation
49
+
50
+ ### Step 1: Plan Your Strategy
51
+
52
+ Before writing code, answer these questions:
53
+
54
+ 1. **What additional data does your strategy need?**
55
+ - Example: Inpainting needs mask latents alongside video latents
56
+ - Example: Audio-to-video needs reference audio embeddings
57
+
58
+ 2. **What does conditioning look like?**
59
+ - Which tokens should be noised vs. kept clean?
60
+ - How should conditioning tokens be structured (e.g., first frame, reference video, mask)?
61
+
62
+ 3. **How should loss be computed?**
63
+ - Which tokens contribute to the loss?
64
+ - Are there multiple loss terms to combine?
65
+
66
+ ### Step 2: Extend Data Preprocessing (If Needed)
67
+
68
+ If your strategy requires additional preprocessed data beyond video latents, audio latents, and text embeddings, you'll
69
+ need to extend the preprocessing pipeline.
70
+
71
+ #### Option A: Modify `process_dataset.py`
72
+
73
+ For integrated preprocessing, add new arguments and processing steps to the main script. For example, to add mask
74
+ preprocessing:
75
+
76
+ ```python
77
+ # In process_dataset.py, add a new argument
78
+ @app.command()
79
+ def main(
80
+ # ... existing arguments ...
81
+ mask_column: str | None = typer.Option(
82
+ default=None,
83
+ help="Column name containing mask video paths (for inpainting)",
84
+ ),
85
+ ) -> None:
86
+ # ... existing processing ...
87
+
88
+ # Process masks if provided
89
+ if mask_column:
90
+ logger.info("Processing mask videos for inpainting training...")
91
+ mask_latents_dir = output_base / "mask_latents"
92
+
93
+ compute_latents(
94
+ dataset_file=dataset_path,
95
+ video_column=mask_column,
96
+ resolution_buckets=parsed_resolution_buckets,
97
+ output_dir=str(mask_latents_dir),
98
+ model_path=model_path,
99
+ # ... other args ...
100
+ )
101
+ ```
102
+
103
+ #### Option B: Create a Standalone Script
104
+
105
+ For complex preprocessing that doesn't fit naturally into the existing pipeline, create a dedicated script
106
+ (e.g., `scripts/process_masks.py`). Use [`scripts/compute_reference.py`](../scripts/compute_reference.py) as a
107
+ template - it shows how to process paired data and update the dataset JSON.
108
+
109
+ #### Expected Output Structure
110
+
111
+ Your preprocessing should create a directory structure that the strategy can reference:
112
+
113
+ ```
114
+ preprocessed_data_root/
115
+ ├── latents/ # Video latents (standard)
116
+ ├── conditions/ # Text embeddings (standard)
117
+ ├── audio_latents/ # Audio latents (if with_audio)
118
+ ├── mask_latents/ # Your custom data directory
119
+ └── reference_latents/ # Reference videos (for IC-LoRA)
120
+ ```
121
+
122
+ ### Step 3: Create the Strategy Configuration
123
+
124
+ Create a new file for your strategy (e.g., `src/ltx_trainer/training_strategies/inpainting.py`):
125
+
126
+ ```python
127
+ """Inpainting training strategy.
128
+
129
+ This strategy implements video inpainting training where:
130
+ - Mask latents indicate which regions to inpaint
131
+ - Loss is computed only on masked (inpainted) regions
132
+ """
133
+
134
+ from typing import Any, Literal
135
+
136
+ import torch
137
+ from pydantic import Field
138
+ from torch import Tensor
139
+
140
+ from ltx_core.model.transformer.modality import Modality
141
+ from ltx_trainer.timestep_samplers import TimestepSampler
142
+ from ltx_trainer.training_strategies.base_strategy import (
143
+ ModelInputs,
144
+ TrainingStrategy,
145
+ TrainingStrategyConfigBase,
146
+ )
147
+
148
+
149
+ class InpaintingConfig(TrainingStrategyConfigBase):
150
+ """Configuration for inpainting training strategy."""
151
+
152
+ # The 'name' field acts as a discriminator for the config union
153
+ name: Literal["inpainting"] = "inpainting"
154
+
155
+ mask_latents_dir: str = Field(
156
+ default="mask_latents",
157
+ description="Directory name for mask latents",
158
+ )
159
+
160
+ # Add any strategy-specific parameters
161
+ mask_threshold: float = Field(
162
+ default=0.5,
163
+ description="Threshold for binary mask conversion",
164
+ ge=0.0,
165
+ le=1.0,
166
+ )
167
+ ```
168
+
169
+ **Key points:**
170
+
171
+ - Inherit from `TrainingStrategyConfigBase`
172
+ - Use `Literal["your_strategy_name"]` for the `name` field - this enables automatic strategy selection
173
+ - Use Pydantic `Field` for validation and documentation
174
+
175
+ ### Step 4: Implement the Strategy Class
176
+
177
+ ```python
178
+ class InpaintingStrategy(TrainingStrategy):
179
+ """Inpainting training strategy.
180
+
181
+ Trains the model to fill in masked regions of videos while
182
+ keeping unmasked regions as conditioning.
183
+ """
184
+
185
+ config: InpaintingConfig
186
+
187
+ def __init__(self, config: InpaintingConfig):
188
+ super().__init__(config)
189
+
190
+ @property
191
+ def requires_audio(self) -> bool:
192
+ """Whether this strategy requires audio components."""
193
+ return False # Set to True if your strategy needs audio
194
+
195
+ def get_data_sources(self) -> dict[str, str]:
196
+ """Define which data directories to load.
197
+
198
+ Returns a mapping of directory names to batch keys.
199
+ The trainer will load .pt files from each directory and
200
+ make them available in the batch under the specified key.
201
+ """
202
+ return {
203
+ "latents": "latents", # -> batch["latents"]
204
+ "conditions": "conditions", # -> batch["conditions"]
205
+ self.config.mask_latents_dir: "masks", # -> batch["masks"]
206
+ }
207
+
208
+ def prepare_training_inputs(
209
+ self,
210
+ batch: dict[str, Any],
211
+ timestep_sampler: TimestepSampler,
212
+ ) -> ModelInputs:
213
+ """Transform batch data into model inputs.
214
+
215
+ This is where the core training logic lives:
216
+ 1. Extract and patchify latents
217
+ 2. Sample noise and apply it appropriately
218
+ 3. Create conditioning masks
219
+ 4. Build Modality objects for the transformer
220
+ """
221
+ # Get video latents [B, C, F, H, W]
222
+ latents_data = batch["latents"]
223
+ video_latents = latents_data["latents"]
224
+
225
+ # Get dimensions
226
+ num_frames = latents_data["num_frames"][0].item()
227
+ height = latents_data["height"][0].item()
228
+ width = latents_data["width"][0].item()
229
+
230
+ # Patchify: [B, C, F, H, W] -> [B, seq_len, C]
231
+ video_latents = self._video_patchifier.patchify(video_latents)
232
+
233
+ batch_size, seq_len, _ = video_latents.shape
234
+ device = video_latents.device
235
+ dtype = video_latents.dtype
236
+
237
+ # Get mask latents and process them
238
+ mask_data = batch["masks"]
239
+ mask_latents = mask_data["latents"]
240
+ mask_latents = self._video_patchifier.patchify(mask_latents)
241
+
242
+ # Create binary mask: True = inpaint this region, False = keep original
243
+ inpaint_mask = mask_latents.mean(dim=-1) > self.config.mask_threshold
244
+
245
+ # Sample noise and sigmas
246
+ sigmas = timestep_sampler.sample_for(video_latents)
247
+ noise = torch.randn_like(video_latents)
248
+
249
+ # Apply noise only to inpaint regions
250
+ sigmas_expanded = sigmas.view(-1, 1, 1)
251
+ noisy_latents = (1 - sigmas_expanded) * video_latents + sigmas_expanded * noise
252
+
253
+ # Keep original latents for non-inpaint regions (conditioning)
254
+ inpaint_mask_expanded = inpaint_mask.unsqueeze(-1)
255
+ noisy_latents = torch.where(inpaint_mask_expanded, noisy_latents, video_latents)
256
+
257
+ # Create per-token timesteps
258
+ # Conditioning tokens (non-inpaint) get timestep=0
259
+ # Inpaint tokens get the sampled sigma
260
+ timesteps = self._create_per_token_timesteps(~inpaint_mask, sigmas.squeeze())
261
+
262
+ # Compute targets (velocity prediction: noise - clean)
263
+ targets = noise - video_latents
264
+
265
+ # Get text embeddings
266
+ conditions = batch["conditions"]
267
+ video_prompt_embeds = conditions["video_prompt_embeds"]
268
+ prompt_attention_mask = conditions["prompt_attention_mask"]
269
+
270
+ # Generate position embeddings
271
+ positions = self._get_video_positions(
272
+ num_frames=num_frames,
273
+ height=height,
274
+ width=width,
275
+ batch_size=batch_size,
276
+ fps=24.0, # Or get from latents_data
277
+ device=device,
278
+ dtype=dtype,
279
+ )
280
+
281
+ # Create video Modality
282
+ video_modality = Modality(
283
+ enabled=True,
284
+ latent=noisy_latents,
285
+ sigma=sigmas,
286
+ timesteps=timesteps,
287
+ positions=positions,
288
+ context=video_prompt_embeds,
289
+ context_mask=prompt_attention_mask,
290
+ )
291
+
292
+ # Loss mask: only compute loss on inpaint regions
293
+ loss_mask = inpaint_mask
294
+
295
+ return ModelInputs(
296
+ video=video_modality,
297
+ audio=None,
298
+ video_targets=targets,
299
+ audio_targets=None,
300
+ video_loss_mask=loss_mask,
301
+ audio_loss_mask=None,
302
+ )
303
+
304
+ def compute_loss(
305
+ self,
306
+ video_pred: Tensor,
307
+ audio_pred: Tensor | None,
308
+ inputs: ModelInputs,
309
+ ) -> Tensor:
310
+ """Compute training loss on inpaint regions only."""
311
+ # MSE loss
312
+ loss = (video_pred - inputs.video_targets).pow(2)
313
+
314
+ # Apply loss mask
315
+ loss_mask = inputs.video_loss_mask.unsqueeze(-1).float()
316
+ loss = loss.mul(loss_mask).div(loss_mask.mean() + 1e-8)
317
+
318
+ return loss.mean()
319
+ ```
320
+
321
+ ### Step 5: Register the Strategy
322
+
323
+ You need to register your strategy in two places:
324
+
325
+ **1. Update [`src/ltx_trainer/training_strategies/__init__.py`](../src/ltx_trainer/training_strategies/__init__.py):**
326
+
327
+ ```python
328
+ # Add import for your strategy
329
+ from ltx_trainer.training_strategies.inpainting import InpaintingConfig, InpaintingStrategy
330
+
331
+ # Add to the TrainingStrategyConfig type alias
332
+ TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig | InpaintingConfig
333
+
334
+ # Add to __all__
335
+ __all__ = [
336
+ # ... existing exports ...
337
+ "InpaintingConfig",
338
+ "InpaintingStrategy",
339
+ ]
340
+
341
+
342
+ # Add case in get_training_strategy()
343
+ def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy:
344
+ match config:
345
+ # ... existing cases ...
346
+ case InpaintingConfig():
347
+ strategy = InpaintingStrategy(config)
348
+ ```
349
+
350
+ **2. Update [`src/ltx_trainer/config.py`](../src/ltx_trainer/config.py):**
351
+
352
+ ```python
353
+ # Add import
354
+ from ltx_trainer.training_strategies.inpainting import InpaintingConfig
355
+
356
+ # Add to the TrainingStrategyConfig union with a Tag matching your strategy name
357
+ TrainingStrategyConfig = Annotated[
358
+ Annotated[TextToVideoConfig, Tag("text_to_video")]
359
+ | Annotated[VideoToVideoConfig, Tag("video_to_video")]
360
+ | Annotated[InpaintingConfig, Tag("inpainting")], # Add your config
361
+ Discriminator(_get_strategy_discriminator),
362
+ ]
363
+ ```
364
+
365
+ ### Step 6: Create a Configuration File
366
+
367
+ Create an example config in `configs/`:
368
+
369
+ ```yaml
370
+ # configs/ltx2_inpainting_lora.yaml
371
+
372
+ model:
373
+ model_path: "/path/to/ltx2.safetensors"
374
+ text_encoder_path: "/path/to/gemma"
375
+ training_mode: "lora"
376
+
377
+ training_strategy:
378
+ name: "inpainting" # Must match your Literal type
379
+ mask_latents_dir: "mask_latents"
380
+ mask_threshold: 0.5
381
+
382
+ lora:
383
+ rank: 32
384
+ alpha: 32
385
+ target_modules:
386
+ - "to_k"
387
+ - "to_q"
388
+ - "to_v"
389
+ - "to_out.0"
390
+
391
+ data:
392
+ preprocessed_data_root: "/path/to/preprocessed/dataset"
393
+
394
+ optimization:
395
+ learning_rate: 1e-4
396
+ steps: 2000
397
+ batch_size: 1
398
+
399
+ # ... other config sections ...
400
+ ```
401
+
402
+ ## 🔧 Helper Methods Reference
403
+
404
+ The base `TrainingStrategy` class provides these helper methods:
405
+
406
+ | Method | Purpose |
407
+ |----------------------------------------------|-------------------------------------------------|
408
+ | `_video_patchifier.patchify(latents)` | Convert `[B, C, F, H, W]` → `[B, seq_len, C]` |
409
+ | `_audio_patchifier.patchify(latents)` | Convert `[B, C, T, F]` → `[B, T, C*F]` |
410
+ | `_get_video_positions(...)` | Generate position embeddings for video |
411
+ | `_get_audio_positions(...)` | Generate position embeddings for audio |
412
+ | `_create_per_token_timesteps(mask, sigma)` | Create timesteps with 0 for conditioning tokens |
413
+ | `_create_first_frame_conditioning_mask(...)` | Create mask for first-frame conditioning |
414
+
415
+ ## 📊 Understanding ModelInputs
416
+
417
+ The `ModelInputs` dataclass contains everything needed for the forward pass and loss computation:
418
+
419
+ ```python
420
+ @dataclass
421
+ class ModelInputs:
422
+ video: Modality # Video modality data
423
+ audio: Modality | None # Audio modality (None if video-only)
424
+
425
+ video_targets: Tensor # Target values for loss (velocity)
426
+ audio_targets: Tensor | None
427
+
428
+ video_loss_mask: Tensor # Boolean: True = compute loss for this token
429
+ audio_loss_mask: Tensor | None
430
+
431
+ ref_seq_len: int | None = None # For IC-LoRA: reference sequence length
432
+ ```
433
+
434
+ ## 📊 Understanding Modality
435
+
436
+ The `Modality` dataclass (from ltx-core) represents a single modality's data:
437
+
438
+ ```python
439
+ @dataclass(frozen=True)
440
+ class Modality:
441
+ enabled: bool # Whether this modality is active
442
+ latent: Tensor # [B, seq_len, C] - the latent tokens
443
+ timesteps: Tensor # [B, seq_len] - per-token timesteps (sigmas)
444
+ positions: Tensor # [B, dims, seq_len, 2] - position bounds
445
+ context: Tensor # [B, ctx_len, C] - text embeddings
446
+ context_mask: Tensor # [B, ctx_len] - attention mask for context
447
+ ```
448
+
449
+ > [!NOTE]
450
+ > **Per-token timesteps:** Each token in the sequence has its own timestep. Conditioning tokens—those that should remain
451
+ > un-noised—must have `timestep=0`. This is how the model distinguishes clean reference tokens from tokens to denoise. Use
452
+ > `_create_per_token_timesteps(conditioning_mask, sigma)` to set this up correctly.
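+
+ The snippet below illustrates the idea with plain tensors (illustrative only, not the helper's actual implementation):
+
+ ```python
+ import torch
+
+ seq_len = 8
+ # First two tokens are clean conditioning (e.g. first-frame tokens), the rest are denoised
+ conditioning_mask = torch.tensor([True, True, False, False, False, False, False, False])
+ sigma = 0.7  # sampled noise level for this example
+
+ timesteps = torch.full((seq_len,), sigma)
+ timesteps[conditioning_mask] = 0.0  # conditioning tokens get timestep 0
+ ```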
453
+
454
+ > [!NOTE]
455
+ > `Modality` is immutable (frozen dataclass). Use `dataclasses.replace()` to create modified copies.
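+
+ A minimal illustration of `dataclasses.replace()` with a stand-in frozen dataclass:
+
+ ```python
+ from dataclasses import dataclass, replace
+
+ @dataclass(frozen=True)
+ class Frozen:  # stand-in for a frozen dataclass such as Modality
+     latent: tuple
+     enabled: bool
+
+ original = Frozen(latent=(1, 2, 3), enabled=True)
+ updated = replace(original, latent=(4, 5, 6))  # new instance; `original` is unchanged
+ ```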
456
+
457
+ ## ✅ Testing Your Strategy
458
+
459
+ 1. **Verify your training configuration is valid:**
460
+ ```bash
461
+ uv run python -c "
462
+ from ltx_trainer.config import LtxTrainerConfig
463
+ import yaml
464
+
465
+ with open('configs/ltx2_inpainting_lora.yaml') as f:
466
+ config = LtxTrainerConfig(**yaml.safe_load(f))
467
+ print(f'Strategy: {config.training_strategy.name}')
468
+ "
469
+ ```
470
+
471
+ 2. **Test strategy instantiation:**
472
+ ```bash
473
+ uv run python -c "
474
+ from ltx_trainer.training_strategies import get_training_strategy
475
+ from ltx_trainer.training_strategies.inpainting import InpaintingConfig
476
+
477
+ config = InpaintingConfig()
478
+ strategy = get_training_strategy(config)
479
+ print(f'Data sources: {strategy.get_data_sources()}')
480
+ "
481
+ ```
482
+
483
+ 3. **Run a short training test:**
484
+ ```bash
485
+ uv run python scripts/train.py configs/ltx2_inpainting_lora.yaml
486
+ ```
487
+
488
+ ## 💡 Tips and Best Practices
489
+
490
+ ### Debugging
491
+
492
+ - Set `data.num_dataloader_workers: 0` to get clearer error messages
493
+ - Use a small dataset and few steps for initial testing
494
+ - Check tensor shapes at each step with print statements
495
+
496
+ ## 🔗 Related Documentation
497
+
498
+ - [Training Modes](training-modes.md) - Overview of built-in training modes
499
+ - [Configuration Reference](configuration-reference.md) - All configuration options
500
+ - [Dataset Preparation](dataset-preparation.md) - Preprocessing workflow
501
+ - [ltx-core Documentation](../../ltx-core/README.md) - Core model components
502
+
503
+ ## 📚 Reference: Existing Strategies
504
+
505
+ Study these implementations for guidance:
506
+
507
+ | Strategy | Complexity | Key Features |
508
+ |------------------------------------------------------------------------------------|------------|------------------------------------------------|
509
+ | [`TextToVideoStrategy`](../src/ltx_trainer/training_strategies/text_to_video.py) | Simple | First-frame conditioning, optional audio |
510
+ | [`VideoToVideoStrategy`](../src/ltx_trainer/training_strategies/video_to_video.py) | Medium | Reference video concatenation, split loss mask |
packages/ltx-trainer/docs/dataset-preparation.md ADDED
@@ -0,0 +1,342 @@
1
+ # Dataset Preparation Guide
2
+
3
+ This guide covers the complete workflow for preparing and preprocessing your dataset for training.
4
+
5
+ ## 📋 Overview
6
+
7
+ The general dataset preparation workflow is:
8
+
9
+ 1. **(Optional)** Split long videos into scenes using `split_scenes.py`
10
+ 2. **(Optional)** Generate captions for your videos using `caption_videos.py`
11
+ 3. **Preprocess your dataset** using `process_dataset.py` to compute and cache video/audio latents and text embeddings
12
+ 4. **Run the trainer** with your preprocessed dataset
13
+
14
+ ## 🎬 Step 1: Split Scenes
15
+
16
+ If you're starting with raw, long-form videos (e.g., downloaded from YouTube), you should first split them into shorter, coherent scenes.
17
+
18
+ ```bash
19
+ uv run python scripts/split_scenes.py input.mp4 scenes_output_dir/ \
20
+ --filter-shorter-than 5s
21
+ ```
22
+
23
+ This will create multiple video clips in `scenes_output_dir`.
24
+ These clips will be the input for the captioning step, if you choose to use it.
25
+
26
+ The script supports many configuration options for scene detection (detector algorithms, thresholds, minimum scene lengths, etc.):
27
+
28
+ ```bash
29
+ uv run python scripts/split_scenes.py --help
30
+ ```
31
+
32
+ ## 📝 Step 2: Caption Videos
33
+
34
+ If your dataset doesn't include captions, you can automatically generate them using multimodal models that understand both video and audio.
35
+
36
+ ```bash
37
+ uv run python scripts/caption_videos.py scenes_output_dir/ \
38
+ --output scenes_output_dir/dataset.json
39
+ ```
40
+
41
+ If you're running into VRAM issues, try enabling 8-bit quantization to reduce memory usage:
42
+
43
+ ```bash
44
+ uv run python scripts/caption_videos.py scenes_output_dir/ \
45
+ --output scenes_output_dir/dataset.json \
46
+ --use-8bit
47
+ ```
48
+
49
+ This will create a `dataset.json` file containing video paths and their captions.
50
+
51
+ **Captioning options:**
52
+
53
+ | Option | Description |
54
+ |--------|-------------|
55
+ | `--captioner-type` | `qwen_omni` (default, local) or `gemini_flash` (API) |
56
+ | `--use-8bit` | Enable 8-bit quantization for lower VRAM usage |
57
+ | `--no-audio` | Disable audio processing (video-only captions) |
58
+ | `--override` | Re-caption files that already have captions |
59
+ | `--api-key` | API key for Gemini Flash (or set `GOOGLE_API_KEY` env var) |
60
+
61
+ **Caption format:**
62
+
63
+ The captioner produces structured captions with sections for:
64
+ - **Visual content**: People, objects, actions, settings, colors, movements
65
+ - **Speech transcription**: Word-for-word transcription of spoken content
66
+ - **Sounds**: Music, ambient sounds, sound effects
67
+ - **On-screen text**: Any visible text overlays
68
+
69
+ > [!NOTE]
70
+ > The automatically generated captions may contain inaccuracies or hallucinated content.
71
+ > We recommend reviewing and correcting the generated captions in your `dataset.json` file before proceeding to preprocessing.
72
+
73
+ ## ⚡ Step 3: Dataset Preprocessing
74
+
75
+ This step preprocesses your video dataset by:
76
+
77
+ 1. Resizing and cropping videos to fit specified resolution buckets
78
+ 2. Computing and caching video latent representations
79
+ 3. Computing and caching text embeddings for captions
80
+ 4. (Optional) Computing and caching audio latents
81
+
82
+ > [!WARNING]
83
+ > Very large videos (especially high spatial resolution and/or many frames) can cause GPU out-of-memory (OOM)
84
+ > during preprocessing/encoding.
85
+ > The simplest fix is to reduce the target resolution (spatially: width/height) and/or the number of frames
86
+ > (temporally) by using `--resolution-buckets` with smaller dimensions (lower width/height and/or fewer frames).
87
+
88
+ ### Basic Usage
89
+
90
+ ```bash
91
+ uv run python scripts/process_dataset.py dataset.json \
92
+ --resolution-buckets "960x544x49" \
93
+ --model-path /path/to/ltx-2-model.safetensors \
94
+ --text-encoder-path /path/to/gemma-model
95
+ ```
96
+
97
+ ### With Audio Processing
98
+
99
+ For audio-video training, add the `--with-audio` flag:
100
+
101
+ ```bash
102
+ uv run python scripts/process_dataset.py dataset.json \
103
+ --resolution-buckets "960x544x49" \
104
+ --model-path /path/to/ltx-2-model.safetensors \
105
+ --text-encoder-path /path/to/gemma-model \
106
+ --with-audio
107
+ ```
108
+
109
+ ### 📊 Dataset Format
110
+
111
+ The trainer supports either videos or single images.
112
+ Note that your dataset must be homogeneous - either all videos or all images; mixing is not supported.
113
+
114
+ > [!TIP]
115
+ > **Image Datasets:** When using images, follow the same preprocessing steps and format requirements as with videos,
116
+ > but use `1` for the frame count in the resolution bucket (e.g., `960x544x1`).
117
+
118
+ The dataset must be a CSV, JSON, or JSONL metadata file with columns for captions and video paths:
119
+
120
+ **JSON format example:**
121
+
122
+ ```json
123
+ [
124
+ {
125
+ "caption": "A cat playing with a ball of yarn",
126
+ "media_path": "videos/cat_playing.mp4"
127
+ },
128
+ {
129
+ "caption": "A dog running in the park",
130
+ "media_path": "videos/dog_running.mp4"
131
+ }
132
+ ]
133
+ ```
134
+
135
+ **JSONL format example:**
136
+
137
+ ```jsonl
138
+ {"caption": "A cat playing with a ball of yarn", "media_path": "videos/cat_playing.mp4"}
139
+ {"caption": "A dog running in the park", "media_path": "videos/dog_running.mp4"}
140
+ ```
141
+
142
+ **CSV format example:**
143
+
144
+ ```csv
145
+ caption,media_path
146
+ "A cat playing with a ball of yarn","videos/cat_playing.mp4"
147
+ "A dog running in the park","videos/dog_running.mp4"
148
+ ```
149
+
150
+ ### 📐 Resolution Buckets
151
+
152
+ Videos are organized into "buckets" of specific dimensions (width × height × frames).
153
+ Each video is assigned to the nearest matching bucket.
154
+ You can preprocess with one or multiple resolution buckets.
155
+ When training with multiple resolution buckets, you must use a batch size of 1.
156
+
157
+ The dimensions of each bucket must follow these constraints due to LTX-2's VAE architecture:
158
+
159
+ - **Spatial dimensions** (width and height) must be multiples of 32
160
+ - **Number of frames** must satisfy `frames % 8 == 1` (e.g., 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121, etc.)
161
+
162
+ **Guidelines for choosing training resolution:**
163
+
164
+ - For high-quality, detailed videos: use larger spatial dimensions (e.g. 768x448) with fewer frames (e.g. 89)
165
+ - For longer, motion-focused videos: use smaller spatial dimensions (512×512) with more frames (121)
166
+ - Memory usage increases with both spatial and temporal dimensions
167
+
168
+ **Example usage:**
169
+
170
+ ```bash
171
+ uv run python scripts/process_dataset.py dataset.json \
172
+ --resolution-buckets "960x544x49" \
173
+ --model-path /path/to/ltx-2-model.safetensors \
174
+ --text-encoder-path /path/to/gemma-model
175
+ ```
176
+
177
+ Multiple buckets are supported by separating entries with `;`:
178
+
179
+ ```bash
180
+ uv run python scripts/process_dataset.py dataset.json \
181
+ --resolution-buckets "960x544x49;512x512x49" \
182
+ --model-path /path/to/ltx-2-model.safetensors \
183
+ --text-encoder-path /path/to/gemma-model
184
+ ```
185
+
186
+ **Video processing workflow:**
187
+
188
+ 1. Videos are **resized** maintaining aspect ratio until either width or height matches the target
189
+ 2. The larger dimension is **center cropped** to match the bucket's dimensions
190
+ 3. Only the **first X frames are taken** to match the bucket's frame count, remaining frames are ignored
191
+
192
+ > [!NOTE]
193
+ > The sequence length processed by the transformer model can be calculated as:
194
+ >
195
+ > ```
196
+ > sequence_length = (H/32) * (W/32) * ((F-1)/8 + 1)
197
+ > ```
198
+ >
199
+ > Where:
200
+ > - H = Height of video
201
+ > - W = Width of video
202
+ > - F = Number of frames
203
+ > - 32 = VAE's spatial downsampling factor
204
+ > - 8 = VAE's temporal downsampling factor
205
+ >
206
+ > For example, a 768×448×89 video would have sequence length:
207
+ > ```
208
+ > (768/32) * (448/32) * ((89-1)/8 + 1) = 24 * 14 * 12 = 4,032
209
+ > ```
210
+ >
211
+ > Keep this in mind when choosing video dimensions, as longer sequences require more GPU memory.
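+ >
+ > The same calculation as a small Python helper (a sketch that simply implements the formula above):
+ >
+ > ```python
+ > def ltx2_sequence_length(width: int, height: int, frames: int) -> int:
+ >     assert width % 32 == 0 and height % 32 == 0 and frames % 8 == 1
+ >     return (width // 32) * (height // 32) * ((frames - 1) // 8 + 1)
+ >
+ > print(ltx2_sequence_length(768, 448, 89))  # 4032
+ > ```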
212
+
213
+ > [!WARNING]
214
+ > When training with multiple resolution buckets, you must use a batch size of 1
215
+ > (i.e., set `optimization.batch_size: 1` in your training config).
216
+
217
+ ### 📁 Output Structure
218
+
219
+ The preprocessed data is saved in a `.precomputed` directory:
220
+
221
+ ```
222
+ dataset/
223
+ └── .precomputed/
224
+ ├── latents/ # Cached video latents
225
+ ├── conditions/ # Cached text embeddings
226
+ ├── audio_latents/ # (only if --with-audio) Cached audio latents
227
+ └── reference_latents/ # (only for IC-LoRA) Cached reference video latents
228
+ ```
229
+
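+ After preprocessing finishes, a quick file-count check over these directories can catch incomplete runs
+ (for example, a clip that failed to encode). A minimal sketch; the directory names come from the layout above,
+ while the dataset location is a placeholder:
+
+ ```python
+ from pathlib import Path
+
+ precomputed = Path("dataset/.precomputed")  # adjust to your dataset root
+
+ for name in ("latents", "conditions", "audio_latents", "reference_latents"):
+     subdir = precomputed / name
+     if subdir.is_dir():
+         n_files = sum(1 for f in subdir.rglob("*") if f.is_file())
+         print(f"{name}: {n_files} files")
+     else:
+         print(f"{name}: not present (only created for some configurations)")
+ ```
+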
230
+ ## 🪄 IC-LoRA Reference Video Preprocessing
231
+
232
+ For IC-LoRA training, you need to preprocess datasets that include reference videos.
233
+ Reference videos provide the conditioning input while target videos represent the desired transformed output.
234
+
235
+ ### Dataset Format with Reference Videos
236
+
237
+ **JSON format:**
238
+
239
+ ```json
240
+ [
241
+ {
242
+ "caption": "A cat playing with a ball of yarn",
243
+ "media_path": "videos/cat_playing.mp4",
244
+ "reference_path": "references/cat_playing_depth.mp4"
245
+ }
246
+ ]
247
+ ```
248
+
249
+ **JSONL format:**
250
+
251
+ ```jsonl
252
+ {"caption": "A cat playing with a ball of yarn", "media_path": "videos/cat_playing.mp4", "reference_path": "references/cat_playing_depth.mp4"}
253
+ {"caption": "A dog running in the park", "media_path": "videos/dog_running.mp4", "reference_path": "references/dog_running_depth.mp4"}
254
+ ```
255
+
256
+ ### Preprocessing with Reference Videos
257
+
258
+ To preprocess a dataset with reference videos, add the `--reference-column` argument specifying the name of the field
259
+ in your dataset JSON/JSONL/CSV that contains the reference video paths:
260
+
261
+ ```bash
262
+ uv run python scripts/process_dataset.py dataset.json \
263
+ --resolution-buckets "960x544x49" \
264
+ --model-path /path/to/ltx-2-model.safetensors \
265
+ --text-encoder-path /path/to/gemma-model \
266
+ --reference-column "reference_path"
267
+ ```
268
+
269
+ This will create an additional `reference_latents/` directory containing the preprocessed reference video latents.
270
+
271
+
272
+ ### Generating Reference Videos
273
+
274
+ **Dataset Requirements for IC-LoRA:**
275
+
276
+ - Your dataset must contain paired videos where each target video has a corresponding reference video
277
+ - Reference and target videos must have *identical* resolution and length
278
+ - Both reference and target videos should be preprocessed together using the same resolution buckets
279
+
280
+ We provide an example script, [`scripts/compute_reference.py`](../scripts/compute_reference.py), to generate reference
281
+ videos for a given dataset. The default implementation generates Canny edge reference videos.
282
+
283
+ ```bash
284
+ uv run python scripts/compute_reference.py scenes_output_dir/ \
285
+ --output scenes_output_dir/dataset.json
286
+ ```
287
+
288
+ The script accepts a JSON file as the dataset configuration and updates it in-place by adding the filenames of the generated reference videos.
289
+
290
+ If you want to generate a different type of condition (depth maps, pose skeletons, etc.), modify or replace the `compute_reference()` function within this script.
291
+
292
+ ### Example Dataset
293
+
294
+ For reference, see our **[Canny Control Dataset](https://huggingface.co/datasets/Lightricks/Canny-Control-Dataset)** which demonstrates proper IC-LoRA dataset structure with paired videos and Canny edge maps.
295
+
296
+
297
+ ## 🎯 LoRA Trigger Words
298
+
299
+ When training a LoRA, you can specify a trigger token that will be prepended to all captions:
300
+
301
+ ```bash
302
+ uv run python scripts/process_dataset.py dataset.json \
303
+ --resolution-buckets "960x544x49" \
304
+ --model-path /path/to/ltx-2-model.safetensors \
305
+ --text-encoder-path /path/to/gemma-model \
306
+ --lora-trigger "MYTRIGGER"
307
+ ```
308
+
309
+ This acts as a trigger word that activates the LoRA during inference when you include the same token in your prompts.
310
+
311
+ > [!NOTE]
312
+ > There is no need to manually insert the trigger word into your dataset JSON/JSONL/CSV file.
313
+ > The trigger word specified with `--lora-trigger` is automatically prepended to each caption during preprocessing.
314
+
315
+ ## 🔍 Decoding Videos for Verification
316
+
317
+ If you add the `--decode` flag, the script will VAE-decode the precomputed latents and save the resulting videos
318
+ in `.precomputed/decoded_videos`. When audio preprocessing is enabled (`--with-audio`), audio latents will also be
319
+ decoded and saved to `.precomputed/decoded_audio`. This allows you to visually and audibly inspect the processed data.
320
+
321
+ ```bash
322
+ uv run python scripts/process_dataset.py dataset.json \
323
+ --resolution-buckets "960x544x49" \
324
+ --model-path /path/to/ltx-2-model.safetensors \
325
+ --text-encoder-path /path/to/gemma-model \
326
+ --decode
327
+ ```
328
+
329
+ For single-frame images, the decoded latents will be saved as PNG files rather than MP4 videos.
330
+
331
+ ## 🚀 Next Steps
332
+
333
+ Once your dataset is preprocessed, you can proceed to:
334
+
335
+ - Configure your training parameters in [Configuration Reference](configuration-reference.md)
336
+ - Choose your training approach in [Training Modes](training-modes.md)
337
+ - Start training with the [Training Guide](training-guide.md)
338
+
339
+ > [!TIP]
340
+ > If your training recipe requires additional preprocessed data (e.g., masks, conditioning signals), see
341
+ > [Implementing Custom Training Strategies](custom-training-strategies.md) for guidance on extending the
342
+ > preprocessing pipeline.
packages/ltx-trainer/docs/quick-start.md ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quick Start Guide
2
+
3
+ Get up and running with LTX-2 training in just a few steps!
4
+
5
+ ## 📋 Prerequisites
6
+
7
+ Before you begin, ensure you have:
8
+
9
+ 1. **LTX-2 Model Checkpoint** - A local `.safetensors` file containing the LTX-2 model weights.
10
+ Download `ltx-2-19b-dev.safetensors` from: [HuggingFace Hub](https://huggingface.co/Lightricks/LTX-2)
11
+ 2. **Gemma Text Encoder** - A local directory containing the Gemma model (required for LTX-2).
12
+ Download from: [HuggingFace Hub](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/)
13
+ 3. **Linux with CUDA** - The trainer requires `triton`, which is Linux-only
14
+ 4. **GPU with sufficient VRAM** - 80GB recommended for the standard config. For GPUs with 32GB VRAM (e.g., RTX 5090),
15
+ use the [low VRAM config](../configs/ltx2_av_lora_low_vram.yaml) which enables INT8 quantization and other
16
+ memory optimizations
17
+
18
+ ## ⚡ Installation
19
+
20
+ First, install [uv](https://docs.astral.sh/uv/getting-started/installation/) if you haven't already.
21
+ Then clone the repository and install the dependencies:
22
+
23
+ ```bash
24
+ git clone https://github.com/Lightricks/LTX-2
+ cd LTX-2
25
+ ```
26
+
27
+ The `ltx-trainer` package is part of the `LTX-2` monorepo. Install the dependencies from the repository root,
28
+ then navigate to the trainer package:
29
+
30
+ ```bash
31
+ # From the repository root
32
+ uv sync
33
+ cd packages/ltx-trainer
34
+ ```
35
+
36
+ > [!NOTE]
37
+ > The trainer depends on [`ltx-core`](../../ltx-core/) and [`ltx-pipelines`](../../ltx-pipelines/)
38
+ > packages which are automatically installed from the monorepo.
39
+
40
+ ## 🏋 Training Workflow
41
+
42
+ ### 1. Prepare Your Dataset
43
+
44
+ Organize your videos and captions, then preprocess them:
45
+
46
+ ```bash
47
+ # Split long videos into scenes (optional)
48
+ uv run python scripts/split_scenes.py input.mp4 scenes_output_dir/ --filter-shorter-than 5s
49
+
50
+ # Generate captions for videos (optional)
51
+ uv run python scripts/caption_videos.py scenes_output_dir/ --output dataset.json
52
+
53
+ # Preprocess the dataset (compute latents and embeddings)
54
+ uv run python scripts/process_dataset.py dataset.json \
55
+ --resolution-buckets "960x544x49" \
56
+ --model-path /path/to/ltx-2-model.safetensors \
57
+ --text-encoder-path /path/to/gemma-model
58
+ ```
59
+
60
+ See [Dataset Preparation](dataset-preparation.md) for detailed instructions.
61
+
62
+ ### 2. Configure Training
63
+
64
+ Create or modify a configuration YAML file. Start with one of the example configs:
65
+
66
+ - [`configs/ltx2_av_lora.yaml`](../configs/ltx2_av_lora.yaml) - Audio-video LoRA training
67
+ - [`configs/ltx2_av_lora_low_vram.yaml`](../configs/ltx2_av_lora_low_vram.yaml) - Audio-video LoRA training (optimized for 32GB VRAM)
68
+ - [`configs/ltx2_v2v_ic_lora.yaml`](../configs/ltx2_v2v_ic_lora.yaml) - IC-LoRA video-to-video
69
+
70
+ Key settings to update:
71
+
72
+ ```yaml
73
+ model:
74
+ model_path: "/path/to/ltx-2-model.safetensors"
75
+ text_encoder_path: "/path/to/gemma-model"
76
+
77
+ data:
78
+ preprocessed_data_root: "/path/to/preprocessed/data"
79
+
80
+ output_dir: "outputs/my_training_run"
81
+ ```
82
+
83
+ See [Configuration Reference](configuration-reference.md) for all available options.
84
+
85
+ ### 3. Start Training
86
+
87
+ ```bash
88
+ uv run python scripts/train.py configs/ltx2_av_lora.yaml
89
+ ```
90
+
91
+ For multi-GPU training:
92
+
93
+ ```bash
94
+ uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml
95
+ ```
96
+
97
+ See [Training Guide](training-guide.md) for distributed training and advanced options.
98
+
99
+ ## 🎯 Training Modes
100
+
101
+ The trainer supports several training modes:
102
+
103
+ | Mode | Description | Config Example |
104
+ |----------------------|--------------------------------|--------------------------------------------|
105
+ | **LoRA** | Efficient adapter training | `training_strategy.name: "text_to_video"` |
106
+ | **Audio-Video LoRA** | Joint audio-video training | `training_strategy.with_audio: true` |
107
+ | **IC-LoRA** | Video-to-video transformations | `training_strategy.name: "video_to_video"` |
108
+ | **Full Fine-tuning** | Full model training | `model.training_mode: "full"` |
109
+
110
+ See [Training Modes](training-modes.md) for detailed explanations,
111
+ or [Custom Training Strategies](custom-training-strategies.md) if you need to implement your own training recipe.
112
+
113
+ ## Next Steps
114
+
115
+ Once you've completed your first training run, you can:
116
+
117
+ - **Use your trained LoRA for inference** - The [`ltx-pipelines`](../../ltx-pipelines/) package provides
118
+ production-ready inference
119
+ pipelines for various use cases (T2V, I2V, IC-LoRA, etc.). See the package documentation for details.
120
+ - Learn more about [Dataset Preparation](dataset-preparation.md) for advanced preprocessing
121
+ - Explore different [Training Modes](training-modes.md) (LoRA, Audio-Video, IC-LoRA)
122
+ - Dive deeper into [Training Configuration](configuration-reference.md)
123
+ - Understand the model architecture in [LTX-Core Documentation](../../ltx-core/README.md)
124
+
125
+ ## Need Help?
126
+
127
+ If you run into issues at any step, see the [Troubleshooting Guide](troubleshooting.md) for solutions to common
128
+ problems.
129
+
130
+ Join our [Discord community](https://discord.gg/ltxplatform) for real-time help and discussion!
packages/ltx-trainer/docs/training-guide.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Guide
2
+
3
+ This guide covers how to run training jobs, from basic single-GPU training to advanced distributed setups and automatic
4
+ model uploads.
5
+
6
+ ## ⚡ Basic Training (Single GPU)
7
+
8
+ After preprocessing your dataset and preparing a configuration file, you can start training using the trainer script:
9
+
10
+ ```bash
11
+ uv run python scripts/train.py configs/ltx2_av_lora.yaml
12
+ ```
13
+
14
+ The trainer will:
15
+
16
+ 1. **Load your configuration** and validate all parameters
17
+ 2. **Initialize models** and apply optimizations
18
+ 3. **Run the training loop** with progress tracking
19
+ 4. **Generate validation videos** (if configured)
20
+ 5. **Save the trained weights** in your output directory
21
+
22
+ ### Output Files
23
+
24
+ **For LoRA training:**
25
+
26
+ - `lora_weights.safetensors` - Main LoRA weights file
27
+ - `training_config.yaml` - Copy of training configuration
28
+ - `validation_samples/` - Generated validation videos (if enabled)
29
+
30
+ **For full model fine-tuning:**
31
+
32
+ - `model_weights.safetensors` - Full model weights
33
+ - `training_config.yaml` - Copy of training configuration
34
+ - `validation_samples/` - Generated validation videos (if enabled)
35
+
36
+ ## 🖥️ Distributed / Multi-GPU Training
37
+
38
+ We use Hugging Face 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) for multi-GPU DDP and FSDP.
39
+
40
+ ### Configure Accelerate
41
+
42
+ Run the interactive wizard once to set up your environment (DDP / FSDP, GPU count, etc.):
43
+
44
+ ```bash
45
+ uv run accelerate config
46
+ ```
47
+
48
+ This stores your preferences in `~/.cache/huggingface/accelerate/default_config.yaml`.
49
+
50
+ ### Use the Provided Accelerate Configs (Recommended)
51
+
52
+ We include ready-to-use Accelerate config files in `configs/accelerate/`:
53
+
54
+ - [ddp.yaml](../configs/accelerate/ddp.yaml) — Standard DDP
55
+ - [ddp_compile.yaml](../configs/accelerate/ddp_compile.yaml) — DDP with `torch.compile` (Inductor)
56
+ - [fsdp.yaml](../configs/accelerate/fsdp.yaml) — Standard FSDP (auto-wraps `BasicAVTransformerBlock`)
57
+ - [fsdp_compile.yaml](../configs/accelerate/fsdp_compile.yaml) — FSDP with `torch.compile` (Inductor)
58
+
59
+ Launch with a specific config using `--config_file`:
60
+
61
+ ```bash
62
+ # DDP (2 GPUs shown as example)
63
+ CUDA_VISIBLE_DEVICES=0,1 \
64
+ uv run accelerate launch --config_file configs/accelerate/ddp.yaml \
65
+ scripts/train.py configs/ltx2_av_lora.yaml
66
+
67
+ # DDP + torch.compile
68
+ CUDA_VISIBLE_DEVICES=0,1 \
69
+ uv run accelerate launch --config_file configs/accelerate/ddp_compile.yaml \
70
+ scripts/train.py configs/ltx2_av_lora.yaml
71
+
72
+ # FSDP (4 GPUs shown as example)
73
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
74
+ uv run accelerate launch --config_file configs/accelerate/fsdp.yaml \
75
+ scripts/train.py configs/ltx2_av_lora.yaml
76
+
77
+ # FSDP + torch.compile
78
+ CUDA_VISIBLE_DEVICES=0,1,2,3 \
79
+ uv run accelerate launch --config_file configs/accelerate/fsdp_compile.yaml \
80
+ scripts/train.py configs/ltx2_av_lora.yaml
81
+ ```
82
+
83
+ **Notes:**
84
+
85
+ - The number of processes is taken from the Accelerate config (`num_processes`). Override with `--num_processes X` or
86
+ restrict GPUs with `CUDA_VISIBLE_DEVICES`.
87
+ - The compile variants enable `torch.compile` with the Inductor backend via Accelerate's `dynamo_config`.
88
+ - FSDP configs auto-wrap the transformer blocks (`fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock`).
89
+
90
+ ### Launch with Your Default Accelerate Config
91
+
92
+ If you prefer to use your default Accelerate profile:
93
+
94
+ ```bash
95
+ # Use settings from your default accelerate config
96
+ uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml
97
+
98
+ # Override number of processes on the fly (e.g., 2 GPUs)
99
+ uv run accelerate launch --num_processes 2 scripts/train.py configs/ltx2_av_lora.yaml
100
+
101
+ # Select specific GPUs
102
+ CUDA_VISIBLE_DEVICES=0,1 uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml
103
+ ```
104
+
105
+ > [!TIP]
106
+ > You can disable the in-terminal progress bars with the `--disable-progress-bars` flag in the trainer CLI if desired.
107
+
108
+ ### Benefits of Distributed Training
109
+
110
+ - **Faster training**: Distribute workload across multiple GPUs
111
+ - **Larger effective batch sizes**: Combine gradients from multiple GPUs
112
+ - **Memory efficiency**: Each GPU handles a portion of the batch
113
+
114
+ > [!NOTE]
115
+ > Distributed training requires that all GPUs have sufficient memory for the model and batch size. The effective batch
116
+ > size becomes `batch_size × num_processes`.
117
+
118
+ ## 🤗 Pushing Models to Hugging Face Hub
119
+
120
+ You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration:
121
+
122
+ ```yaml
123
+ hub:
124
+ push_to_hub: true
125
+ hub_model_id: "your-username/your-model-name"
126
+ ```
127
+
128
+ ### Prerequisites
129
+
130
+ Before pushing, make sure you:
131
+
132
+ 1. **Have a Hugging Face account** - Sign up at [huggingface.co](https://huggingface.co)
133
+ 2. **Are logged in** via `huggingface-cli login` or have set the `HUGGING_FACE_HUB_TOKEN` environment variable
134
+ 3. **Have write access** to the specified repository (it will be created if it doesn't exist)
135
+
136
+ ### Login Options
137
+
138
+ **Option 1: Interactive login**
139
+
140
+ ```bash
141
+ uv run huggingface-cli login
142
+ ```
143
+
144
+ **Option 2: Environment variable**
145
+
146
+ ```bash
147
+ export HUGGING_FACE_HUB_TOKEN="your_token_here"
148
+ ```
149
+
150
+ ### What Gets Uploaded
151
+
152
+ The trainer will automatically:
153
+
154
+ - **Create a model card** with training details and sample outputs
155
+ - **Upload model weights**
156
+ - **Push sample videos as GIFs** in the model card
157
+ - **Include training configuration and prompts**
158
+
159
+ ## 📊 Weights & Biases Logging
160
+
161
+ Enable experiment tracking with W&B by adding to your configuration:
162
+
163
+ ```yaml
164
+ wandb:
165
+ enabled: true
166
+ project: "ltx-2-trainer"
167
+ entity: null # Your W&B username or team
168
+ tags: [ "ltx2", "lora" ]
169
+ log_validation_videos: true
170
+ ```
171
+
172
+ This will log:
173
+
174
+ - Training loss and learning rate
175
+ - Validation videos
176
+ - Model configuration
177
+ - Training progress
178
+
179
+ ## 🚀 Next Steps
180
+
181
+ After training completes:
182
+
183
+ - **Run inference with your trained LoRA** - The [`ltx-pipelines`](../../ltx-pipelines/) package provides
184
+ production-ready inference
185
+ pipelines that support loading custom LoRAs. Available pipelines include text-to-video, image-to-video,
186
+ IC-LoRA video-to-video, and more. See the [`ltx-pipelines`](../../ltx-pipelines/) package for usage details.
187
+ - **Test your model** with validation prompts
188
+ - **Iterate and improve** based on validation results
189
+ - **Share your results** by pushing to Hugging Face Hub
190
+
191
+ ## 💡 Tips for Successful Training
192
+
193
+ - **Start small**: Begin with a small dataset and a few hundred steps to verify everything works
194
+ - **Monitor validation**: Keep an eye on validation samples to catch overfitting
195
+ - **Adjust learning rate**: Lower learning rates often produce better results
196
+ - **Use gradient checkpointing**: Essential for training with limited GPU memory
197
+ - **Save checkpoints**: Regular checkpoints help recover from interruptions
198
+
199
+ ## Need Help?
200
+
201
+ If you encounter issues during training, see the [Troubleshooting Guide](troubleshooting.md).
202
+
203
+ Join our [Discord community](https://discord.gg/ltxplatform) for real-time help!
packages/ltx-trainer/docs/training-modes.md ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Modes Guide
2
+
3
+ The trainer supports several training modes, each suited for different use cases and requirements.
4
+
5
+ ## 🎯 Standard LoRA Training (Video-Only)
6
+
7
+ Standard LoRA (Low-Rank Adaptation) training fine-tunes the model by adding small, trainable adapter layers while
8
+ keeping the base model frozen. This approach:
9
+
10
+ - **Requires significantly less memory and compute** than full fine-tuning
11
+ - **Produces small, portable weight files** (typically a few hundred MB)
12
+ - **Is ideal for learning specific styles, effects, or concepts**
13
+ - **Can be easily combined with other LoRAs** during inference
14
+
15
+ Configure standard LoRA training with:
16
+
17
+ ```yaml
18
+ model:
19
+ training_mode: "lora"
20
+
21
+ training_strategy:
22
+ name: "text_to_video"
23
+ first_frame_conditioning_p: 0.1
24
+ with_audio: false # Video-only training
25
+ ```
26
+
27
+ ## 🔊 Audio-Video LoRA Training
28
+
29
+ LTX-2 supports joint audio-video generation. You can train LoRA adapters that affect both video and audio output:
30
+
31
+ - **Synchronized audio-video generation** - Audio matches the visual content
32
+ - **Same efficient LoRA approach** - Just enable audio training
33
+ - **Requires audio latents** - Dataset must include preprocessed audio
34
+
35
+ Configure audio-video training with:
36
+
37
+ ```yaml
38
+ model:
39
+ training_mode: "lora"
40
+
41
+ training_strategy:
42
+ name: "text_to_video"
43
+ first_frame_conditioning_p: 0.1
44
+ with_audio: true # Enable audio training
45
+ audio_latents_dir: "audio_latents" # Directory containing audio latents
46
+ ```
47
+
48
+ **Example configuration file:**
49
+
50
+ - 📄 [Audio-Video LoRA Training](../configs/ltx2_av_lora.yaml)
51
+
52
+ **Dataset structure for audio-video training:**
53
+
54
+ ```
55
+ preprocessed_data_root/
56
+ ├── latents/ # Video latents
57
+ ├── conditions/ # Text embeddings
58
+ └── audio_latents/ # Audio latents (required when with_audio: true)
59
+ ```
60
+
61
+ > [!IMPORTANT]
62
+ > When training audio-video LoRAs, ensure your `target_modules` configuration captures video, audio, and
63
+ > cross-modal attention branches. Use patterns like `"to_k"` instead of `"attn1.to_k"` to match:
64
+ > - Video modules: `attn1.to_k`, `attn2.to_k`
65
+ > - Audio modules: `audio_attn1.to_k`, `audio_attn2.to_k`
66
+ > - Cross-modal modules: `audio_to_video_attn.to_k`, `video_to_audio_attn.to_k`
67
+ >
68
+ > The cross-modal attention modules (`audio_to_video_attn` and `video_to_audio_attn`) enable bidirectional
69
+ > information flow between audio and video, which is critical for synchronized audiovisual generation.
70
+ > See [Understanding Target Modules](configuration-reference.md#understanding-target-modules) for detailed guidance.
71
+
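+ To check which branches a pattern would reach, you can dry-run it against the module names listed above. This sketch
+ assumes simple suffix matching on the dotted module path (PEFT-style target modules); the trainer's exact matching
+ rules may differ:
+
+ ```python
+ # Module names from the note above (the real model repeats these across many blocks).
+ module_names = [
+     "attn1.to_k", "attn2.to_k",
+     "audio_attn1.to_k", "audio_attn2.to_k",
+     "audio_to_video_attn.to_k", "video_to_audio_attn.to_k",
+ ]
+
+ def matches(name: str, patterns: list[str]) -> bool:
+     # Suffix match on the dotted path: "to_k" matches "audio_attn1.to_k", etc.
+     return any(name == p or name.endswith("." + p) for p in patterns)
+
+ print([m for m in module_names if matches(m, ["attn1.to_k"])])  # only the video self-attention
+ print([m for m in module_names if matches(m, ["to_k"])])        # all six modules above
+ ```
+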
72
+ > [!NOTE]
73
+ > You can generate audio during validation even if you're not training the audio branch.
74
+ > Set `validation.generate_audio: true` independently of `training_strategy.with_audio`.
75
+
76
+ ## 🔥 Full Model Fine-tuning
77
+
78
+ Full model fine-tuning updates all parameters of the base model, providing maximum flexibility but
79
+ requiring substantial computational resources and larger training datasets:
80
+
81
+ - **Offers the highest potential quality and capability improvements**
82
+ - **Requires multiple GPUs** and distributed training techniques (e.g., FSDP)
83
+ - **Produces large checkpoint files** (several GB)
84
+ - **Best for major model adaptations** or when LoRA limitations are reached
85
+
86
+ Configure full fine-tuning with:
87
+
88
+ ```yaml
89
+ model:
90
+ training_mode: "full"
91
+
92
+ training_strategy:
93
+ name: "text_to_video"
94
+ first_frame_conditioning_p: 0.1
95
+ ```
96
+
97
+ > [!IMPORTANT]
98
+ > Full fine-tuning of LTX-2 requires multiple high-end GPUs (e.g., 4-8× H100 80GB) and distributed
99
+ > training with FSDP. See [Training Guide](training-guide.md) for multi-GPU setup instructions.
100
+
101
+ ## 🔄 In-Context LoRA (IC-LoRA) Training
102
+
103
+ IC-LoRA is a specialized training mode for video-to-video transformations.
104
+ Unlike standard training modes that learn from individual videos, IC-LoRA learns transformations from pairs of videos.
105
+ IC-LoRA enables a wide range of advanced video-to-video applications, such as:
106
+
107
+ - **Control adapters** (e.g., Depth, Pose): Learn to map from a control signal (like a depth map or pose skeleton) to a
108
+ target video
109
+ - **Video deblurring**: Transform blurry input videos into sharp, high-quality outputs
110
+ - **Style transfer**: Apply the style of a reference video to a target video sequence
111
+ - **Colorization**: Convert grayscale reference videos into colorized outputs
112
+ - **Restoration and enhancement**: Denoise, upscale, or restore old or degraded videos
113
+
114
+ By providing paired reference and target videos, IC-LoRA can learn complex transformations that go beyond caption-based
115
+ conditioning.
116
+
117
+ IC-LoRA training fundamentally differs from standard LoRA and full fine-tuning:
118
+
119
+ - **Reference videos** provide clean, unnoised conditioning input showing the "before" state
120
+ - **Target videos** are noised during training and represent the desired "after" state
121
+ - **The model learns transformations** from reference videos to target videos
122
+ - **Loss is applied only to the target portion**, not the reference
123
+ - **Training and inference time increase significantly** due to the doubled sequence length
124
+
125
+ To enable IC-LoRA training, configure your YAML file with:
126
+
127
+ ```yaml
128
+ model:
129
+ training_mode: "lora" # Required: IC-LoRA uses LoRA mode
130
+
131
+ training_strategy:
132
+ name: "video_to_video"
133
+ first_frame_conditioning_p: 0.1
134
+ reference_latents_dir: "reference_latents" # Directory for reference video latents
135
+ ```
136
+
137
+ **Example configuration file:**
138
+
139
+ - 📄 [IC-LoRA Training](../configs/ltx2_v2v_ic_lora.yaml) - Video-to-video transformation training
140
+
141
+ ### Dataset Requirements for IC-LoRA
142
+
143
+ - Your dataset must contain **paired videos** where each target video has a corresponding reference video
144
+ - Reference and target videos must have the **same frame count** (length)
145
+ - Reference videos can optionally be at **lower spatial resolution** than target videos (
146
+ see [Scaled Reference Conditioning](#scaled-reference-conditioning) below)
147
+ - Both reference and target videos should be **preprocessed** before training
148
+
149
+ **Dataset structure for IC-LoRA training:**
150
+
151
+ ```
152
+ preprocessed_data_root/
153
+ ├── latents/ # Target video latents (what the model learns to generate)
154
+ ├── conditions/ # Text embeddings for each video
155
+ └── reference_latents/ # Reference video latents (conditioning input)
156
+ ```
157
+
158
+ ### Generating Reference Videos
159
+
160
+ We provide an example script to generate reference videos (e.g., Canny edge maps) for a given dataset.
161
+ The script takes a JSON file as input (e.g., output of `caption_videos.py`) and updates it with the generated reference
162
+ video paths.
163
+
164
+ ```bash
165
+ uv run python scripts/compute_reference.py scenes_output_dir/ \
166
+ --output scenes_output_dir/dataset.json
167
+ ```
168
+
169
+ To compute a different condition (depth maps, pose skeletons, etc.), modify the `compute_reference()` function in the
170
+ script.
171
+
172
+ ### Configuration Requirements for IC-LoRA
173
+
174
+ - You **must** provide `reference_videos` in your validation configuration when using IC-LoRA training
175
+ - The number of reference videos must match the number of validation prompts
176
+
177
+ Example validation configuration for IC-LoRA:
178
+
179
+ ```yaml
180
+ validation:
181
+ prompts:
182
+ - "First prompt describing the desired output"
183
+ - "Second prompt describing the desired output"
184
+ reference_videos:
185
+ - "/path/to/reference1.mp4"
186
+ - "/path/to/reference2.mp4"
187
+ reference_downscale_factor: 1 # Set to match preprocessing (e.g., 2 for half resolution)
188
+ include_reference_in_output: true # Show reference side-by-side with output
189
+ ```
190
+
191
+ ### Scaled Reference Conditioning
192
+
193
+ For more efficient training and inference, you can use **downscaled reference videos** while keeping target videos at
194
+ full resolution. This reduces the number of conditioning tokens, leading to:
195
+
196
+ - **Faster training** due to shorter sequence lengths
197
+ - **Faster inference** with reduced memory usage
198
+ - **Same aspect ratio** maintained between reference and target
199
+
200
+ #### How It Works
201
+
202
+ When the reference video has resolution `H/n × W/n` and the target video has resolution `H × W`, the trainer
203
+ automatically detects this scale factor `n` and adjusts the positional encodings so that the reference positions
204
+ map to the correct locations in the target coordinate space.
205
+
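+ Concretely, the factor can be recovered from the two resolutions. A small illustrative check (the trainer performs
+ its own validation; this helper is not part of its API):
+
+ ```python
+ def infer_reference_scale(target_hw: tuple[int, int], reference_hw: tuple[int, int]) -> int:
+     """Infer n given a target of (H, W) and a reference of (H/n, W/n)."""
+     (th, tw), (rh, rw) = target_hw, reference_hw
+     if th % rh or tw % rw or th // rh != tw // rw:
+         raise ValueError("Reference must be the target divided by the same integer on both axes")
+     if rh % 32 or rw % 32:
+         raise ValueError("Reference dimensions must still be divisible by 32")
+     return th // rh
+
+ print(infer_reference_scale((768, 768), (384, 384)))  # 2
+ ```
+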
206
+ #### Preprocessing Datasets with Scaled References
207
+
208
+ Use the `--reference-downscale-factor` option when running `process_dataset.py`:
209
+
210
+ ```bash
211
+ # Process dataset with scaled reference videos (half resolution)
212
+ uv run python scripts/process_dataset.py dataset.json \
213
+ --resolution-buckets 768x768x25 \
214
+ --model-path /path/to/ltx2.safetensors \
215
+ --text-encoder-path /path/to/gemma \
216
+ --reference-column "reference_path" \
217
+ --reference-downscale-factor 2
218
+ ```
219
+
220
+ This will:
221
+
222
+ - Process target videos at 768×768 resolution
223
+ - Process reference videos at 384×384 resolution (768 / 2)
224
+ - The trainer will automatically infer the scale factor from the dimension ratio
225
+
226
+ **Important**: Set `reference_downscale_factor: 2` in your validation configuration to match the preprocessing:
227
+
228
+ ```yaml
229
+ validation:
230
+ reference_downscale_factor: 2 # Must match the preprocessing factor
231
+ reference_videos:
232
+ - "/path/to/reference1.mp4"
233
+ - "/path/to/reference2.mp4"
234
+ ```
235
+
236
+ > [!NOTE]
237
+ > The scale factor must be a positive integer, and all dimensions must be divisible by 32.
238
+ > Common scale factors are 1 (no scaling), 2 (half resolution), or 4 (quarter resolution).
239
+
240
+ ## 📊 Training Mode Comparison
241
+
242
+ | Aspect | LoRA | Audio-Video LoRA | Full Fine-tuning | IC-LoRA |
243
+ |----------------------|--------------------------------|--------------------------------|------------------|--------------------------------|
244
+ | **Memory Usage** | Low | Low-Medium | High | Medium |
245
+ | **Training Speed** | Fast | Fast | Slow | Medium |
246
+ | **Output Size** | 100MB-few GB (depends on rank) | 100MB-few GB (depends on rank) | Tens of GB | 100MB-few GB (depends on rank) |
247
+ | **Flexibility** | Medium | Medium | High | Specialized |
248
+ | **Audio Support** | Optional | Yes | Optional | No |
249
+ | **Reference Videos** | No | No | No | Yes (required) |
250
+
251
+ ## 🎬 Using Trained Models for Inference
252
+
253
+ After training, use the [`ltx-pipelines`](../../ltx-pipelines/) package for production inference with your trained
254
+ LoRAs:
255
+
256
+ | Training Mode | Recommended Pipeline |
257
+ |-------------------------|-------------------------------------------------------|
258
+ | LoRA / Audio-Video LoRA | `TI2VidOneStagePipeline` or `TI2VidTwoStagesPipeline` |
259
+ | IC-LoRA | `ICLoraPipeline` |
260
+
261
+ All pipelines support loading custom LoRAs via the `loras` parameter. See the [`ltx-pipelines`](../../ltx-pipelines/)
262
+ package
263
+ documentation for detailed usage instructions.
264
+
265
+ ## 🚀 Next Steps
266
+
267
+ Once you've chosen your training mode:
268
+
269
+ - Set up your dataset using [Dataset Preparation](dataset-preparation.md)
270
+ - Configure your training parameters in [Configuration Reference](configuration-reference.md)
271
+ - Start training with the [Training Guide](training-guide.md)
272
+
273
+ > [!TIP]
274
+ > Need a training mode that's not covered here?
275
+ > See [Implementing Custom Training Strategies](custom-training-strategies.md)
276
+ > to learn how to create your own strategy for specialized use cases like video inpainting, audio-only training, or
277
+ > custom conditioning.
packages/ltx-trainer/docs/troubleshooting.md ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Troubleshooting Guide
2
+
3
+ This guide covers common issues and solutions when training with the LTX-2 trainer.
4
+
5
+ ## 🔧 VRAM and Memory Issues
6
+
7
+ Memory management is crucial for successful training with LTX-2.
8
+
9
+ > [!TIP]
10
+ > For GPUs with 32GB VRAM, use the pre-configured low VRAM config:
11
+ > [`configs/ltx2_av_lora_low_vram.yaml`](../configs/ltx2_av_lora_low_vram.yaml)
12
+ > which combines 8-bit optimizer, INT8 quantization, and reduced LoRA rank.
13
+
14
+ ### Memory Optimization Techniques
15
+
16
+ #### 1. Enable Gradient Checkpointing
17
+
18
+ Gradient checkpointing trades training speed for memory savings. **Highly recommended** for most training runs:
19
+
20
+ ```yaml
21
+ optimization:
22
+ enable_gradient_checkpointing: true
23
+ ```
24
+
25
+ #### 2. Enable 8-bit Text Encoder
26
+
27
+ Load the Gemma text encoder in 8-bit precision to save GPU memory:
28
+
29
+ ```yaml
30
+ acceleration:
31
+ load_text_encoder_in_8bit: true
32
+ ```
33
+
34
+ #### 3. Reduce Batch Size
35
+
36
+ Lower the batch size if you encounter out-of-memory errors:
37
+
38
+ ```yaml
39
+ optimization:
40
+ batch_size: 1 # Start with 1 and increase gradually
41
+ ```
42
+
43
+ Use gradient accumulation to maintain a larger effective batch size:
44
+
45
+ ```yaml
46
+ optimization:
47
+ batch_size: 1
48
+ gradient_accumulation_steps: 4 # Effective batch size = 4
49
+ ```
50
+
51
+ #### 4. Use Lower Resolution
52
+
53
+ Reduce spatial or temporal dimensions to save memory:
54
+
55
+ ```bash
56
+ # Smaller spatial resolution
57
+ uv run python scripts/process_dataset.py dataset.json \
58
+ --resolution-buckets "512x512x49" \
59
+ --model-path /path/to/model.safetensors \
60
+ --text-encoder-path /path/to/gemma
61
+
62
+ # Fewer frames
63
+ uv run python scripts/process_dataset.py dataset.json \
64
+ --resolution-buckets "960x544x25" \
65
+ --model-path /path/to/model.safetensors \
66
+ --text-encoder-path /path/to/gemma
67
+ ```
68
+
69
+ #### 5. Enable Model Quantization
70
+
71
+ Use quantization to reduce memory usage:
72
+
73
+ ```yaml
74
+ acceleration:
75
+ quantization: "int8-quanto" # Options: int8-quanto, int4-quanto, fp8-quanto
76
+ ```
77
+
78
+ #### 6. Use 8-bit Optimizer
79
+
80
+ The 8-bit AdamW optimizer uses less memory:
81
+
82
+ ```yaml
83
+ optimization:
84
+ optimizer_type: "adamw8bit"
85
+ ```
86
+
87
+ ---
88
+
89
+ ## ⚠️ Common Usage Issues
90
+
91
+ ### Issue: "No module named 'ltx_trainer'" Error
92
+
93
+ **Solution:**
94
+ Ensure you've installed the dependencies and are using `uv run` to execute scripts:
95
+
96
+ ```bash
97
+ # From the repository root
98
+ uv sync
99
+ cd packages/ltx-trainer
100
+ uv run python scripts/train.py configs/ltx2_av_lora.yaml
101
+ ```
102
+
103
+ > [!TIP]
104
+ > Always use `uv run` to execute Python scripts. This automatically uses the correct virtual environment
105
+ > without requiring manual activation.
106
+
107
+ ### Issue: "Gemma model path is not a directory" Error
108
+
109
+ **Solution:**
110
+ The `text_encoder_path` must point to a directory containing the Gemma model, not a file:
111
+
112
+ ```yaml
113
+ model:
114
+ model_path: "/path/to/ltx-2-model.safetensors" # File path
115
+ text_encoder_path: "/path/to/gemma-model/" # Directory path
116
+ ```
117
+
118
+ ### Issue: "Model path does not exist" Error
119
+
120
+ **Solution:**
121
+ LTX-2 requires local model paths. URLs are not supported:
122
+
123
+ ```yaml
124
+ # ✅ Correct - local path
125
+ model:
126
+ model_path: "/path/to/ltx-2-model.safetensors"
127
+
128
+ # ❌ Wrong - URL not supported
129
+ model:
130
+ model_path: "https://huggingface.co/..."
131
+ ```
132
+
133
+ ### Issue: "Frames must satisfy frames % 8 == 1" Error
134
+
135
+ **Solution:**
136
+ LTX-2 requires the number of frames to satisfy `frames % 8 == 1` (a helper for snapping to a valid count is sketched below):
137
+
138
+ - ✅ Valid: 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121
139
+ - ❌ Invalid: 24, 32, 48, 64, 100
140
+
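+ If your clips have an arbitrary frame count, you can snap down to the nearest valid value before choosing a bucket.
+ A small illustrative helper:
+
+ ```python
+ def nearest_valid_frames(frames: int) -> int:
+     """Round down to the closest count satisfying frames % 8 == 1."""
+     return max(1, frames - (frames - 1) % 8)
+
+ print(nearest_valid_frames(100))  # 97
+ print(nearest_valid_frames(48))   # 41
+ ```
+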
141
+ ### Issue: Slow Training Speed
142
+
143
+ **Optimizations:**
144
+
145
+ 1. **Disable gradient checkpointing** (if you have enough VRAM):
146
+
147
+ ```yaml
148
+ optimization:
149
+ enable_gradient_checkpointing: false
150
+ ```
151
+
152
+
153
+ 2. **Use torch.compile** via Accelerate:
154
+
155
+ ```bash
156
+ uv run accelerate launch --config_file configs/accelerate/ddp_compile.yaml \
157
+ scripts/train.py configs/ltx2_av_lora.yaml
158
+ ```
159
+
160
+ ### Issue: Poor Quality Validation Outputs
161
+
162
+ **Solutions:**
163
+
164
+ 1. **Use Image-to-Video Validation:**
165
+ For more reliable validation, use image-to-video (first-frame conditioning) rather than pure text-to-video:
166
+
167
+ ```yaml
168
+ validation:
169
+ prompts:
170
+ - "a professional portrait video of a person"
171
+ images:
172
+ - "/path/to/first_frame.png" # One image per prompt
173
+ ```
174
+
175
+ 2. **Increase inference steps:**
176
+
177
+ ```yaml
178
+ validation:
179
+ inference_steps: 50 # Default is 30
180
+ ```
181
+
182
+ 3. **Adjust guidance settings:**
183
+
184
+ ```yaml
185
+ validation:
186
+ guidance_scale: 4.0 # CFG scale (recommended: 4.0)
187
+ stg_scale: 1.0 # STG scale for temporal coherence (recommended: 1.0)
188
+ stg_blocks: [29] # Transformer block to perturb
189
+ ```
190
+
191
+ 4. **Check caption quality:**
192
+ Review and manually edit captions for accuracy if using auto-generated captions.
193
+ LTX-2 prefers long, detailed captions that describe both visual content and audio (e.g., ambient sounds, speech,
194
+ music).
195
+
196
+ 5. **Check target modules:**
197
+ Ensure your `target_modules` configuration matches your training goals. For audio-video training,
198
+ use patterns that match both branches (e.g., `"to_k"` instead of `"attn1.to_k"`).
199
+ See [Understanding Target Modules](configuration-reference.md#understanding-target-modules) for details.
200
+
201
+ 6. **Adjust LoRA rank:**
202
+ Try higher values for more capacity:
203
+
204
+ ```yaml
205
+ lora:
206
+ rank: 64 # Or 128 for more capacity
207
+ ```
208
+
209
+ 7. **Increase training steps:**
210
+
211
+ ```yaml
212
+ optimization:
213
+ steps: 3000
214
+ ```
215
+
216
+ ---
217
+
218
+ ## 🔍 Debugging Tools
219
+
220
+ ### Monitor GPU Memory Usage
221
+
222
+ Track memory usage during training:
223
+
224
+ ```bash
225
+ # Watch GPU memory in real-time
226
+ watch -n 1 nvidia-smi
227
+
228
+ # Log memory usage to file
229
+ nvidia-smi --query-gpu=memory.used,memory.total --format=csv --loop=5 > memory_log.csv
230
+ ```
231
+
232
+ ### Verify Preprocessed Data
233
+
234
+ Decode latents to visualize the preprocessed videos:
235
+
236
+ ```bash
237
+ uv run python scripts/decode_latents.py dataset/.precomputed/latents debug_output \
238
+ --model-path /path/to/model.safetensors
239
+ ```
240
+
241
+ To also decode audio latents, add the `--with-audio` flag:
242
+
243
+ ```bash
244
+ uv run python scripts/decode_latents.py dataset/.precomputed/latents debug_output \
245
+ --model-path /path/to/model.safetensors \
246
+ --with-audio
247
+ ```
248
+
249
+ Compare decoded videos and audio with originals to ensure quality.
250
+
251
+ ---
252
+
253
+ ## 💡 Best Practices
254
+
255
+ ### Before Training
256
+
257
+ - [ ] Test preprocessing with a small subset first
258
+ - [ ] Verify all video files are accessible
259
+ - [ ] Check available GPU memory
260
+ - [ ] Review configuration against hardware capabilities
261
+ - [ ] Ensure model and text encoder paths are correct
262
+
263
+ ### During Training
264
+
265
+ - [ ] Monitor GPU memory usage
266
+ - [ ] Check loss convergence regularly
267
+ - [ ] Review validation samples periodically
268
+ - [ ] Save checkpoints frequently
269
+
270
+ ### After Training
271
+
272
+ - [ ] Test trained model with diverse prompts
273
+ - [ ] Document training parameters and results
274
+ - [ ] Archive training data and configs
275
+
276
+ ## 🆘 Getting Help
277
+
278
+ If you're still experiencing issues:
279
+
280
+ 1. **Check logs:** Review console output for error details
281
+ 2. **Search issues:** Look through GitHub issues for similar problems
282
+ 3. **Provide details:** When reporting issues, include:
283
+ - Hardware specifications (GPU model, VRAM)
284
+ - Configuration file used
285
+ - Complete error message
286
+ - Steps to reproduce the issue
287
+
288
+ ---
289
+
290
+ ## 🤝 Join the Community
291
+
292
+ Have questions, want to share your results, or need real-time help?
293
+ Join our [community Discord server](https://discord.gg/ltxplatform)
294
+ to connect with other users and the development team!
295
+
296
+ - Get troubleshooting help
297
+ - Share your training results and workflows
298
+ - Stay up to date with announcements and updates
299
+
300
+ We look forward to seeing you there!
packages/ltx-trainer/docs/utility-scripts.md ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utility Scripts Reference
2
+
3
+ This guide covers the various utility scripts available for preprocessing, conversion, and debugging tasks.
4
+
5
+ ## 🎬 Dataset Processing Scripts
6
+
7
+ ### Video Scene Splitting
8
+
9
+ The `scripts/split_scenes.py` script automatically splits long videos into shorter, coherent scenes.
10
+
11
+ ```bash
12
+ # Basic scene splitting
13
+ uv run python scripts/split_scenes.py input.mp4 output_dir/ --filter-shorter-than 5s
14
+ ```
15
+
16
+ **Key features:**
17
+
18
+ - **Automatic scene detection**: Uses PySceneDetect for intelligent splitting
19
+ - **Multiple algorithms**: Content-based, adaptive, threshold, and histogram detection
20
+ - **Filtering options**: Remove scenes shorter than specified duration
21
+ - **Customizable parameters**: Thresholds, window sizes, and detection modes
22
+
23
+ **Common options:**
24
+
25
+ ```bash
26
+ # See all available options
27
+ uv run python scripts/split_scenes.py --help
28
+
29
+ # Use adaptive detection with custom threshold
30
+ uv run python scripts/split_scenes.py video.mp4 scenes/ --detector adaptive --threshold 30.0
31
+
32
+ # Limit to maximum number of scenes
33
+ uv run python scripts/split_scenes.py video.mp4 scenes/ --max-scenes 50
34
+ ```
35
+
36
+ ### Automatic Video Captioning
37
+
38
+ The `scripts/caption_videos.py` script generates captions for videos (with audio) using multimodal models.
39
+
40
+ ```bash
41
+ # Generate captions for all videos in a directory (uses Qwen2.5-Omni by default)
42
+ uv run python scripts/caption_videos.py videos_dir/ --output dataset.json
43
+
44
+ # Use 8-bit quantization to reduce VRAM usage
45
+ uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --use-8bit
46
+
47
+ # Use Gemini Flash API instead (requires API key)
48
+ uv run python scripts/caption_videos.py videos_dir/ --output dataset.json \
49
+ --captioner-type gemini_flash --api-key YOUR_API_KEY
50
+
51
+ # Caption without audio processing (video-only)
52
+ uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --no-audio
53
+
54
+ # Force re-caption all files
55
+ uv run python scripts/caption_videos.py videos_dir/ --output dataset.json --override
56
+ ```
57
+
58
+ **Key features:**
59
+
60
+ - **Audio-visual captioning**: Processes both video and audio content, including speech transcription
61
+ - **Multiple backends**:
62
+ - `qwen_omni` (default): Local Qwen2.5-Omni model - processes video + audio locally
63
+ - `gemini_flash`: Google Gemini Flash API - cloud-based, requires API key
64
+ - **Structured output**: Captions include visual description, speech transcription, sounds, and on-screen text
65
+ - **Memory optimization**: 8-bit quantization option for limited VRAM
66
+ - **Incremental processing**: Skips already-captioned files by default
67
+ - **Multiple output formats**: JSON, JSONL, CSV, or TXT
68
+
69
+ **Caption format:**
70
+
71
+ The captioner produces structured captions with four sections (a small parsing sketch follows the list):
72
+ - `[VISUAL]`: Detailed description of visual content
73
+ - `[SPEECH]`: Word-for-word transcription of spoken content
74
+ - `[SOUNDS]`: Description of music, ambient sounds, sound effects
75
+ - `[TEXT]`: Any on-screen text visible in the video
76
+
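+ If you post-process the captions (for example, to keep only the visual description), the sections can be separated
+ with a small parser. A sketch, assuming the labels appear literally in the caption text:
+
+ ```python
+ import re
+
+ SECTION_LABELS = ("VISUAL", "SPEECH", "SOUNDS", "TEXT")
+
+ def split_caption_sections(caption: str) -> dict[str, str]:
+     """Split a structured caption into its [VISUAL]/[SPEECH]/[SOUNDS]/[TEXT] parts."""
+     parts = re.split(r"\[(" + "|".join(SECTION_LABELS) + r")\]", caption)
+     # re.split with a capturing group yields [prefix, label, text, label, text, ...]
+     return {label: text.strip() for label, text in zip(parts[1::2], parts[2::2])}
+
+ example = "[VISUAL] A cat bats a ball of yarn. [SOUNDS] Soft purring. [TEXT] none"
+ print(split_caption_sections(example)["VISUAL"])  # A cat bats a ball of yarn.
+ ```
+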
77
+ **Environment variables (for Gemini Flash):**
78
+
79
+ Set one of these to use Gemini Flash without passing `--api-key`:
80
+ - `GOOGLE_API_KEY`
81
+ - `GEMINI_API_KEY`
82
+
83
+ ### Dataset Preprocessing
84
+
85
+ The `scripts/process_dataset.py` script processes videos and caches latents for training.
86
+
87
+ ```bash
88
+ # Basic preprocessing
89
+ uv run python scripts/process_dataset.py dataset.json \
90
+ --resolution-buckets "960x544x49" \
91
+ --model-path /path/to/ltx-2-model.safetensors \
92
+ --text-encoder-path /path/to/gemma-model
93
+
94
+ # With audio processing
95
+ uv run python scripts/process_dataset.py dataset.json \
96
+ --resolution-buckets "960x544x49" \
97
+ --model-path /path/to/ltx-2-model.safetensors \
98
+ --text-encoder-path /path/to/gemma-model \
99
+ --with-audio
100
+
101
+ # With video decoding for verification
102
+ uv run python scripts/process_dataset.py dataset.json \
103
+ --resolution-buckets "960x544x49" \
104
+ --model-path /path/to/ltx-2-model.safetensors \
105
+ --text-encoder-path /path/to/gemma-model \
106
+ --decode
107
+ ```
108
+
109
+ Multiple resolution buckets can be specified, separated by `;`:
110
+
111
+ ```bash
112
+ uv run python scripts/process_dataset.py dataset.json \
113
+ --resolution-buckets "960x544x49;512x512x81" \
114
+ --model-path /path/to/ltx-2-model.safetensors \
115
+ --text-encoder-path /path/to/gemma-model
116
+ ```
117
+
118
+ > [!NOTE]
119
+ > When training with multiple resolution buckets, set `optimization.batch_size: 1`.
120
+
121
+ For detailed usage, see the [Dataset Preparation Guide](dataset-preparation.md).
122
+
123
+ ### Reference Video Generation
124
+
125
+ The `scripts/compute_reference.py` script provides a template for creating reference videos needed for IC-LoRA training.
126
+ The default implementation generates Canny edge reference videos.
127
+
128
+ ```bash
129
+ # Generate Canny edge reference videos
130
+ uv run python scripts/compute_reference.py videos_dir/ --output dataset.json
131
+ ```
132
+
133
+ **Key features:**
134
+
135
+ - **Canny edge detection**: Creates edge-based reference videos
136
+ - **In-place editing**: Updates existing dataset JSON files
137
+ - **Customizable**: Modify the `compute_reference()` function for different conditions (depth, pose, etc.)
138
+
139
+ > [!TIP]
140
+ > You can edit this script to generate other types of reference videos for IC-LoRA training,
141
+ > such as depth maps, segmentation masks, or any custom video transformation.
142
+
143
+ ## 🔍 Debugging and Verification Scripts
144
+
145
+ ### Latents Decoding
146
+
147
+ The `scripts/decode_latents.py` script decodes precomputed video latents back into video files for visual inspection.
148
+
149
+ ```bash
150
+ # Basic usage
151
+ uv run python scripts/decode_latents.py /path/to/latents/dir \
152
+ --output-dir /path/to/output \
153
+ --model-path /path/to/ltx-2-model.safetensors
154
+
155
+ # With VAE tiling for large videos
156
+ uv run python scripts/decode_latents.py /path/to/latents/dir \
157
+ --output-dir /path/to/output \
158
+ --model-path /path/to/ltx-2-model.safetensors \
159
+ --vae-tiling
160
+
161
+ # Decode both video and audio latents
162
+ uv run python scripts/decode_latents.py /path/to/latents/dir \
163
+ --output-dir /path/to/output \
164
+ --model-path /path/to/ltx-2-model.safetensors \
165
+ --with-audio
166
+ ```
167
+
168
+ **The script will:**
169
+
170
+ 1. **Load the VAE model** from the specified path
171
+ 2. **Process all `.pt` latent files** in the input directory
172
+ 3. **Decode each latent** back into a video using the VAE
173
+ 4. **Save resulting videos** as MP4 files in the output directory
174
+
175
+ **When to use:**
176
+
177
+ - **Verify preprocessing quality**: Check that your videos were encoded correctly
178
+ - **Debug training data**: Visualize what the model actually sees during training
179
+ - **Quality assessment**: Ensure latent encoding preserves important visual details
180
+
181
+
182
+ ### Inference Script
183
+
184
+ The `scripts/inference.py` script runs inference with a trained model.
185
+
186
+ > [!TIP]
187
+ > For production inference, consider using the [`ltx-pipelines`](../../ltx-pipelines/) package which provides optimized,
188
+ > feature-rich pipelines for various use cases:
189
+ > - **Text/Image-to-Video**: `TI2VidOneStagePipeline`, `TI2VidTwoStagesPipeline`
190
+ > - **Distilled (fast) inference**: `DistilledPipeline`
191
+ > - **IC-LoRA video-to-video**: `ICLoraPipeline`
192
+ > - **Keyframe interpolation**: `KeyframeInterpolationPipeline`
193
+ >
194
+ > All pipelines support loading custom LoRAs trained with this trainer.
195
+
196
+ ```bash
197
+ # Text-to-video inference (with audio by default)
198
+ # By default, uses CFG scale 4.0 and STG scale 1.0 with block 29
199
+ uv run python scripts/inference.py \
200
+ --checkpoint /path/to/model.safetensors \
201
+ --text-encoder-path /path/to/gemma \
202
+ --prompt "A cat playing with a ball" \
203
+ --output output.mp4
204
+
205
+ # Video-only (skip audio generation)
206
+ uv run python scripts/inference.py \
207
+ --checkpoint /path/to/model.safetensors \
208
+ --text-encoder-path /path/to/gemma \
209
+ --prompt "A cat playing with a ball" \
210
+ --skip-audio \
211
+ --output output.mp4
212
+
213
+ # Image-to-video with conditioning image
214
+ uv run python scripts/inference.py \
215
+ --checkpoint /path/to/model.safetensors \
216
+ --text-encoder-path /path/to/gemma \
217
+ --prompt "A cat walking" \
218
+ --condition-image first_frame.png \
219
+ --output output.mp4
220
+
221
+ # Custom guidance settings
222
+ uv run python scripts/inference.py \
223
+ --checkpoint /path/to/model.safetensors \
224
+ --text-encoder-path /path/to/gemma \
225
+ --prompt "A cat playing with a ball" \
226
+ --guidance-scale 4.0 \
227
+ --stg-scale 1.0 \
228
+ --stg-blocks 29 \
229
+ --output output.mp4
230
+
231
+ # Disable STG (CFG only)
232
+ uv run python scripts/inference.py \
233
+ --checkpoint /path/to/model.safetensors \
234
+ --text-encoder-path /path/to/gemma \
235
+ --prompt "A cat playing with a ball" \
236
+ --stg-scale 0.0 \
237
+ --output output.mp4
238
+ ```
239
+
240
+ **Guidance parameters:**
241
+
242
+ | Parameter | Default | Description |
243
+ |-----------|---------|-------------|
244
+ | `--guidance-scale` | 4.0 | CFG (Classifier-Free Guidance) scale |
245
+ | `--stg-scale` | 1.0 | STG (Spatio-Temporal Guidance) scale. 0.0 disables STG |
246
+ | `--stg-blocks` | 29 | Transformer block(s) to perturb for STG |
247
+ | `--stg-mode` | stg_av | `stg_av` perturbs both audio and video, `stg_v` video only |
248
+
249
+ ## 🚀 Training Scripts
250
+
251
+ ### Basic and Distributed Training
252
+
253
+ Use `scripts/train.py` for both single GPU and multi-GPU runs:
254
+
255
+ ```bash
256
+ # Single-GPU training
257
+ uv run python scripts/train.py configs/ltx2_av_lora.yaml
258
+
259
+ # Multi-GPU (uses your accelerate config)
260
+ uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml
261
+
262
+ # Override number of processes
263
+ uv run accelerate launch --num_processes 4 scripts/train.py configs/ltx2_av_lora.yaml
264
+ ```
265
+
266
+ For detailed usage, see the [Training Guide](training-guide.md).
267
+
268
+ ## 💡 Tips for Using Utility Scripts
269
+
270
+ - **Start with `--help`**: Always check available options for each script
271
+ - **Test on small datasets**: Verify workflows with a few files before processing large datasets
272
+ - **Use decode verification**: Always decode a few samples to verify preprocessing quality
273
+ - **Monitor VRAM usage**: Use `--use-8bit` or quantization flags when running into memory issues
274
+ - **Keep backups**: Make copies of important dataset files before running conversion scripts
packages/ltx-trainer/scripts/caption_videos.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Auto-caption videos with audio using multimodal models.
5
+ This script provides a command-line interface for generating captions for videos
6
+ (including audio) using multimodal models. It supports:
7
+ - Qwen2.5-Omni: Local model for audio-visual captioning (default)
8
+ - Gemini Flash: Cloud-based API for audio-visual captioning
9
+ The paths to videos in the generated dataset/captions file will be RELATIVE to the
10
+ directory where the output file is stored. This makes the dataset more portable and
11
+ easier to use in different environments.
12
+ Basic usage:
13
+ # Caption a single video (includes audio by default)
14
+ caption_videos.py video.mp4 --output captions.json
15
+ # Caption all videos in a directory
16
+ caption_videos.py videos_dir/ --output captions.csv
17
+ # Caption with custom instruction
18
+ caption_videos.py video.mp4 --instruction "Describe what happens in this video in detail."
19
+ Advanced usage:
20
+ # Use Gemini Flash API (requires GEMINI_API_KEY or GOOGLE_API_KEY env var)
21
+ caption_videos.py videos_dir/ --captioner-type gemini_flash
22
+ # Disable audio processing (video-only captions)
23
+ caption_videos.py videos_dir/ --no-audio
24
+ # Process videos with specific extensions and save as JSON
25
+ caption_videos.py videos_dir/ --extensions mp4,mov,avi --output captions.json
26
+ """
27
+
28
+ import csv
29
+ import json
30
+ from enum import Enum
31
+ from pathlib import Path
32
+
33
+ import torch
34
+ import typer
35
+ from rich.console import Console
36
+ from rich.progress import (
37
+ BarColumn,
38
+ MofNCompleteColumn,
39
+ Progress,
40
+ SpinnerColumn,
41
+ TextColumn,
42
+ TimeElapsedColumn,
43
+ TimeRemainingColumn,
44
+ )
45
+ from transformers.utils.logging import disable_progress_bar
46
+
47
+ from ltx_trainer.captioning import CaptionerType, MediaCaptioningModel, create_captioner
48
+
49
+ VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"]
50
+ IMAGE_EXTENSIONS = ["jpg", "jpeg", "png"]
51
+ MEDIA_EXTENSIONS = VIDEO_EXTENSIONS + IMAGE_EXTENSIONS
52
+ SAVE_INTERVAL = 5
53
+
54
+ console = Console()
55
+ app = typer.Typer(
56
+ pretty_exceptions_enable=False,
57
+ no_args_is_help=True,
58
+ help="Auto-caption videos with audio using multimodal models.",
59
+ )
60
+
61
+ disable_progress_bar()
62
+
63
+
64
+ class OutputFormat(str, Enum):
65
+ """Available output formats for captions."""
66
+
67
+ TXT = "txt" # Separate files for captions and video paths, one caption / video path per line
68
+ CSV = "csv" # CSV file with video path and caption columns
69
+ JSON = "json" # JSON file with video paths as keys and captions as values
70
+ JSONL = "jsonl" # JSON Lines file with one JSON object per line
71
+
72
+
73
+ def caption_media(
74
+ input_path: Path,
75
+ output_path: Path,
76
+ captioner: MediaCaptioningModel,
77
+ extensions: list[str],
78
+ recursive: bool,
79
+ fps: int,
80
+ include_audio: bool,
81
+ clean_caption: bool,
82
+ output_format: OutputFormat,
83
+ override: bool,
84
+ ) -> None:
85
+ """Caption videos and images using the provided captioning model.
86
+ Args:
87
+ input_path: Path to input video file or directory
88
+ output_path: Path to output caption file
89
+ captioner: Media captioning model
90
+ extensions: List of media file extensions to include
91
+ recursive: Whether to search subdirectories recursively
92
+ fps: Frames per second to sample from videos (ignored for images)
93
+ include_audio: Whether to include audio in captioning
94
+ clean_caption: Whether to clean up captions
95
+ output_format: Format to save the captions in
96
+ override: Whether to override existing captions
97
+ """
98
+
99
+ # Get list of media files to process
100
+ media_files = _get_media_files(input_path, extensions, recursive)
101
+
102
+ if not media_files:
103
+ console.print("[bold yellow]No media files found to process.[/]")
104
+ return
105
+
106
+ console.print(f"Found [bold]{len(media_files)}[/] media files to process.")
107
+
108
+ # Load existing captions and determine which files need processing
109
+ base_dir = output_path.parent.resolve()
110
+ existing_captions = _load_existing_captions(output_path, output_format)
111
+ existing_abs_paths = {str((base_dir / p).resolve()) for p in existing_captions}
112
+
113
+ if override:
114
+ media_to_process = media_files
115
+ else:
116
+ media_to_process = [f for f in media_files if str(f.resolve()) not in existing_abs_paths]
117
+ if skipped := len(media_files) - len(media_to_process):
118
+ console.print(f"[bold yellow]Skipping {skipped} media that already have captions.[/]")
119
+
120
+ if not media_to_process:
121
+ console.print("[bold yellow]All media already have captions. Use --override to recaption.[/]")
122
+ return
123
+
124
+ # Process media files
125
+ captions = existing_captions.copy()
126
+ successfully_captioned = 0
127
+ progress = Progress(
128
+ SpinnerColumn(),
129
+ TextColumn("{task.description}"),
130
+ BarColumn(bar_width=40),
131
+ MofNCompleteColumn(),
132
+ TimeElapsedColumn(),
133
+ TextColumn("•"),
134
+ TimeRemainingColumn(),
135
+ console=console,
136
+ )
137
+
138
+ with progress:
139
+ task = progress.add_task("Captioning", total=len(media_to_process))
140
+
141
+ for i, media_file in enumerate(media_to_process):
142
+ progress.update(task, description=f"Captioning [bold blue]{media_file.name}[/]")
143
+
144
+ try:
145
+ # Generate caption for the media
146
+ caption = captioner.caption(
147
+ path=media_file,
148
+ fps=fps,
149
+ include_audio=include_audio,
150
+ clean_caption=clean_caption,
151
+ )
152
+
153
+ # Convert absolute path to relative path (relative to the output file's directory)
154
+ rel_path = str(media_file.resolve().relative_to(base_dir))
155
+ # Store the caption with the relative path as key
156
+ captions[rel_path] = caption
157
+ successfully_captioned += 1
158
+ except Exception as e:
159
+ console.print(f"[bold red]Error captioning {media_file}: {e}[/]")
160
+
161
+ if i % SAVE_INTERVAL == 0:
162
+ _save_captions(captions, output_path, output_format)
163
+
164
+ # Advance progress bar
165
+ progress.advance(task)
166
+
167
+ # Save captions to file
168
+ _save_captions(captions, output_path, output_format)
169
+
170
+ # Print summary
171
+ console.print(
172
+ f"[bold green]✓[/] Captioned [bold]{successfully_captioned}/{len(media_to_process)}[/] media successfully.",
173
+ )
174
+
175
+
176
+ def _get_media_files(
177
+ input_path: Path,
178
+ extensions: list[str] = MEDIA_EXTENSIONS,
179
+ recursive: bool = False,
180
+ ) -> list[Path]:
181
+ """Get all media files from the input path."""
182
+ input_path = Path(input_path)
183
+ # Normalize extensions to lowercase without dots
184
+ extensions_set = {ext.lower().lstrip(".") for ext in extensions}
185
+
186
+ if input_path.is_file():
187
+ # If input is a file, check if it has a valid extension
188
+ if input_path.suffix.lstrip(".").lower() in extensions_set:
189
+ return [input_path]
190
+ else:
191
+ typer.echo(f"Warning: {input_path} is not a recognized media file. Skipping.")
192
+ return []
193
+ elif input_path.is_dir():
194
+ # Find all files and filter by extension case-insensitively
195
+ glob_pattern = "**/*" if recursive else "*"
196
+ media_files = [
197
+ f for f in input_path.glob(glob_pattern) if f.is_file() and f.suffix.lstrip(".").lower() in extensions_set
198
+ ]
199
+ return sorted(media_files)
200
+ else:
201
+ typer.echo(f"Error: {input_path} does not exist.")
202
+ raise typer.Exit(code=1)
203
+
204
+
205
+ def _save_captions(
206
+ captions: dict[str, str],
207
+ output_path: Path,
208
+ format_type: OutputFormat,
209
+ ) -> None:
210
+ """Save captions to a file in the specified format.
211
+ Args:
212
+ captions: Dictionary mapping media paths to captions
213
+ output_path: Path to save the output file
214
+ format_type: Format to save the captions in
215
+ """
216
+ # Create parent directories if they don't exist
217
+ output_path.parent.mkdir(parents=True, exist_ok=True)
218
+
219
+ console.print("[bold blue]Saving captions...[/]")
220
+
221
+ match format_type:
222
+ case OutputFormat.TXT:
223
+ # Create two separate files for captions and media paths
224
+ captions_file = output_path.with_stem(f"{output_path.stem}_captions")
225
+ paths_file = output_path.with_stem(f"{output_path.stem}_paths")
226
+
227
+ with captions_file.open("w", encoding="utf-8") as f:
228
+ for caption in captions.values():
229
+ f.write(f"{caption}\n")
230
+
231
+ with paths_file.open("w", encoding="utf-8") as f:
232
+ for media_path in captions:
233
+ f.write(f"{media_path}\n")
234
+
235
+ console.print(f"[bold green]✓[/] Captions saved to [cyan]{captions_file}[/]")
236
+ console.print(f"[bold green]✓[/] Media paths saved to [cyan]{paths_file}[/]")
237
+
238
+ case OutputFormat.CSV:
239
+ with output_path.open("w", encoding="utf-8", newline="") as f:
240
+ writer = csv.writer(f)
241
+ writer.writerow(["caption", "media_path"])
242
+ for media_path, caption in captions.items():
243
+ writer.writerow([caption, media_path])
244
+
245
+ console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
246
+
247
+ case OutputFormat.JSON:
248
+ # Format as list of dictionaries with caption and media_path keys
249
+ json_data = [{"caption": caption, "media_path": media_path} for media_path, caption in captions.items()]
250
+
251
+ with output_path.open("w", encoding="utf-8") as f:
252
+ json.dump(json_data, f, indent=2, ensure_ascii=False)
253
+
254
+ console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
255
+
256
+ case OutputFormat.JSONL:
257
+ with output_path.open("w", encoding="utf-8") as f:
258
+ for media_path, caption in captions.items():
259
+ f.write(json.dumps({"caption": caption, "media_path": media_path}, ensure_ascii=False) + "\n")
260
+
261
+ console.print(f"[bold green]✓[/] Captions saved to [cyan]{output_path}[/]")
262
+
263
+ case _:
264
+ raise ValueError(f"Unsupported output format: {format_type}")
265
+
266
+
267
+ def _load_existing_captions( # noqa: PLR0912
268
+ output_path: Path,
269
+ format_type: OutputFormat,
270
+ ) -> dict[str, str]:
271
+ """Load existing captions from a file.
272
+ Args:
273
+ output_path: Path to the captions file
274
+ format_type: Format of the captions file
275
+ Returns:
276
+ Dictionary mapping media paths to captions, or empty dict if file doesn't exist
277
+ """
278
+ if not output_path.exists():
279
+ return {}
280
+
281
+ console.print(f"[bold blue]Loading existing captions from [cyan]{output_path}[/]...[/]")
282
+
283
+ existing_captions = {}
284
+
285
+ try:
286
+ match format_type:
287
+ case OutputFormat.TXT:
288
+ # For TXT format, we have two separate files
289
+ captions_file = output_path.with_stem(f"{output_path.stem}_captions")
290
+ paths_file = output_path.with_stem(f"{output_path.stem}_paths")
291
+
292
+ if captions_file.exists() and paths_file.exists():
293
+ captions = captions_file.read_text(encoding="utf-8").splitlines()
294
+ paths = paths_file.read_text(encoding="utf-8").splitlines()
295
+
296
+ if len(captions) == len(paths):
297
+ existing_captions = dict(zip(paths, captions, strict=False))
298
+
299
+ case OutputFormat.CSV:
300
+ with output_path.open("r", encoding="utf-8", newline="") as f:
301
+ reader = csv.reader(f)
302
+ # Skip header
303
+ next(reader, None)
304
+ for row in reader:
305
+ if len(row) >= 2:
306
+ caption, media_path = row[0], row[1]
307
+ existing_captions[media_path] = caption
308
+
309
+ case OutputFormat.JSON:
310
+ with output_path.open("r", encoding="utf-8") as f:
311
+ json_data = json.load(f)
312
+ for item in json_data:
313
+ if "caption" in item and "media_path" in item:
314
+ existing_captions[item["media_path"]] = item["caption"]
315
+
316
+ case OutputFormat.JSONL:
317
+ with output_path.open("r", encoding="utf-8") as f:
318
+ for line in f:
319
+ item = json.loads(line)
320
+ if "caption" in item and "media_path" in item:
321
+ existing_captions[item["media_path"]] = item["caption"]
322
+
323
+ case _:
324
+ raise ValueError(f"Unsupported output format: {format_type}")
325
+
326
+ console.print(f"[bold green]✓[/] Loaded [bold]{len(existing_captions)}[/] existing captions")
327
+ return existing_captions
328
+
329
+ except Exception as e:
330
+ console.print(f"[bold yellow]Warning: Could not load existing captions: {e}[/]")
331
+ return {}
332
+
333
+
334
+ @app.command()
335
+ def main( # noqa: PLR0913
336
+ input_path: Path = typer.Argument( # noqa: B008
337
+ ...,
338
+ help="Path to input video/image file or directory containing media files",
339
+ exists=True,
340
+ ),
341
+ output: Path | None = typer.Option( # noqa: B008
342
+ None,
343
+ "--output",
344
+ "-o",
345
+ help="Path to output file for captions. Format determined by file extension.",
346
+ ),
347
+ captioner_type: CaptionerType = typer.Option( # noqa: B008
348
+ CaptionerType.QWEN_OMNI,
349
+ "--captioner-type",
350
+ "-c",
351
+ help="Type of captioner to use. Valid values: 'qwen_omni' (local), 'gemini_flash' (API)",
352
+ case_sensitive=False,
353
+ ),
354
+ device: str | None = typer.Option(
355
+ None,
356
+ "--device",
357
+ "-d",
358
+ help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.",
359
+ ),
360
+ use_8bit: bool = typer.Option(
361
+ False,
362
+ "--use-8bit",
363
+ help="Whether to use 8-bit precision for the captioning model (reduces memory usage)",
364
+ ),
365
+ instruction: str | None = typer.Option(
366
+ None,
367
+ "--instruction",
368
+ "-i",
369
+ help="Custom instruction for the captioning model. If not provided, uses an appropriate default.",
370
+ ),
371
+ extensions: str = typer.Option(
372
+ ",".join(MEDIA_EXTENSIONS),
373
+ "--extensions",
374
+ "-e",
375
+ help="Comma-separated list of media file extensions to process",
376
+ ),
377
+ recursive: bool = typer.Option(
378
+ False,
379
+ "--recursive",
380
+ "-r",
381
+ help="Search for media files in subdirectories recursively",
382
+ ),
383
+ fps: int = typer.Option(
384
+ 3,
385
+ "--fps",
386
+ "-f",
387
+ help="Frames per second to sample from videos (ignored for images)",
388
+ ),
389
+ include_audio: bool = typer.Option(
390
+ True,
391
+ "--audio/--no-audio",
392
+ help="Whether to include audio in captioning (for videos with audio tracks)",
393
+ ),
394
+ clean_caption: bool = typer.Option(
395
+ True,
396
+ "--clean-caption/--raw-caption",
397
+ help="Whether to clean up captions by removing common VLM patterns",
398
+ ),
399
+ override: bool = typer.Option(
400
+ False,
401
+ "--override",
402
+ help="Whether to override existing captions for media",
403
+ ),
404
+ api_key: str | None = typer.Option(
405
+ None,
406
+ "--api-key",
407
+ envvar=["GOOGLE_API_KEY", "GEMINI_API_KEY"],
408
+ help="API key for Gemini Flash (can also use GOOGLE_API_KEY or GEMINI_API_KEY env var)",
409
+ ),
410
+ ) -> None:
411
+ """Auto-caption videos with audio using multimodal models.
412
+ This script supports audio-visual captioning using:
413
+ - Qwen2.5-Omni: Local model (default) - processes both video and audio
414
+ - Gemini Flash: Cloud API - requires GOOGLE_API_KEY environment variable
415
+ The paths in the output file will be relative to the output file's directory.
416
+ Examples:
417
+ # Caption videos with audio using Qwen2.5-Omni (default)
418
+ caption_videos.py videos_dir/ -o captions.json
419
+ # Caption using Gemini Flash API
420
+ caption_videos.py videos_dir/ -o captions.json -c gemini_flash
421
+ # Caption without audio (video-only)
422
+ caption_videos.py videos_dir/ -o captions.json --no-audio
423
+ # Caption with custom instruction
424
+ caption_videos.py video.mp4 -o captions.json -i "Describe this video in detail"
425
+ """
426
+
427
+ # Determine device for local models
428
+ device_str = device or ("cuda" if torch.cuda.is_available() else "cpu")
429
+
430
+ # Parse extensions
431
+ ext_list = [ext.strip() for ext in extensions.split(",")]
432
+
433
+ # Determine output path and format
434
+ if output is None:
435
+ output_format = OutputFormat.JSON
436
+ if input_path.is_file(): # noqa: SIM108
437
+ # Default to a JSON file with the same name as the input media
438
+ output = input_path.with_suffix(".dataset.json")
439
+ else:
440
+ # Default to a JSON file in the input directory
441
+ output = input_path / "dataset.json"
442
+ else:
443
+ # Determine format from file extension
444
+ output_format = OutputFormat(Path(output).suffix.lstrip(".").lower())
445
+
446
+ # Ensure output path is absolute
447
+ output = Path(output).resolve()
448
+ console.print(f"Output will be saved to [bold blue]{output}[/]")
449
+
450
+ # Initialize captioning model
451
+ with console.status("Loading captioning model...", spinner="dots"):
452
+ if captioner_type == CaptionerType.QWEN_OMNI:
453
+ captioner = create_captioner(
454
+ captioner_type=captioner_type,
455
+ device=device_str,
456
+ use_8bit=use_8bit,
457
+ instruction=instruction,
458
+ )
459
+ elif captioner_type == CaptionerType.GEMINI_FLASH:
460
+ captioner = create_captioner(
461
+ captioner_type=captioner_type,
462
+ api_key=api_key,
463
+ instruction=instruction,
464
+ )
465
+ else:
466
+ raise ValueError(f"Unsupported captioner type: {captioner_type}")
467
+
468
+ console.print(f"[bold green]✓[/] {captioner_type.value} captioning model loaded successfully")
469
+
470
+ # Caption media files
471
+ caption_media(
472
+ input_path=input_path,
473
+ output_path=output,
474
+ captioner=captioner,
475
+ extensions=ext_list,
476
+ recursive=recursive,
477
+ fps=fps,
478
+ include_audio=include_audio,
479
+ clean_caption=clean_caption,
480
+ output_format=output_format,
481
+ override=override,
482
+ )
483
+
484
+
485
+ if __name__ == "__main__":
486
+ app()
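For reference, a minimal sketch of the dataset file this script writes when the output format is JSON (field names follow `_save_captions` above; file names and captions are illustrative, and paths are relative to the output file's directory):

    # Contents of captions.json, expressed as the Python structure passed to json.dump()
    json_data = [
        {"caption": "A dog runs across a grassy field while birds chirp.", "media_path": "clips/dog.mp4"},
        {"caption": "Close-up of rain falling on a window at night.", "media_path": "clips/rain.mp4"},
    ]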
packages/ltx-trainer/scripts/compute_reference.py ADDED
@@ -0,0 +1,288 @@
1
+ """
2
+ Compute reference videos for IC-LoRA training.
3
+ This script provides a command-line interface for generating reference videos to be used for IC-LoRA training.
4
+ Note that it reads and writes to the same file (the output of caption_videos.py),
5
+ where it adds the "reference_path" field to the JSON.
6
+ Basic usage:
7
+ # Compute reference videos for all videos in a directory
8
+ compute_reference.py videos_dir/ --output videos_dir/captions.json
9
+ """
10
+
11
+ # Standard library imports
12
+ import json
13
+ from pathlib import Path
14
+ from typing import Dict
15
+
16
+ # Third-party imports
17
+ import cv2
18
+ import torch
19
+ import torchvision.transforms.functional as TF # noqa: N812
20
+ import typer
21
+ from rich.console import Console
22
+ from rich.progress import (
23
+ BarColumn,
24
+ MofNCompleteColumn,
25
+ Progress,
26
+ SpinnerColumn,
27
+ TextColumn,
28
+ TimeElapsedColumn,
29
+ TimeRemainingColumn,
30
+ )
31
+ from transformers.utils.logging import disable_progress_bar
32
+
33
+ # Local imports
34
+ from ltx_trainer.video_utils import read_video, save_video
35
+
36
+ # Initialize console and disable progress bars
37
+ console = Console()
38
+ disable_progress_bar()
39
+
40
+
41
+ def compute_reference(
42
+ images: torch.Tensor,
43
+ ) -> torch.Tensor:
44
+ """Compute Canny edge detection on a batch of images.
45
+ Args:
46
+ images: Batch of images tensor of shape [B, C, H, W]
47
+ Returns:
48
+ Edge map tensor of shape [B, 3, H, W] (Canny edges replicated across 3 channels)
49
+ """
50
+ # Convert to grayscale if needed
51
+ if images.shape[1] == 3:
52
+ images = TF.rgb_to_grayscale(images)
53
+
54
+ # Ensure images are in [0, 1] range
55
+ if images.max() > 1.0:
56
+ images = images / 255.0
57
+
58
+ # Compute Canny edges
59
+ edge_masks = []
60
+ for image in images:
61
+ # Convert to numpy for OpenCV
62
+ image_np = (image.squeeze().cpu().numpy() * 255).astype("uint8")
63
+
64
+ # Apply Canny edge detection
65
+ edges = cv2.Canny(
66
+ image_np,
67
+ threshold1=100,
68
+ threshold2=200,
69
+ )
70
+
71
+ # Convert back to tensor
72
+ edge_mask = torch.from_numpy(edges).float()
73
+ edge_masks.append(edge_mask)
74
+
75
+ edges = torch.stack(edge_masks)
76
+ edges = torch.stack([edges] * 3, dim=1) # Convert to 3-channel
77
+ return edges
78
+
79
+
80
+ def _get_meta_data(
81
+ output_path: Path,
82
+ ) -> list[dict[str, str]]:
83
+ """Get set of existing reference video paths without loading the actual files.
84
+ Args:
85
+ output_path: Path to the dataset JSON file (the output of caption_videos.py)
86
+ Returns:
87
+ List of dataset entries (each with at least a "media_path" key), or an empty list if the file doesn't exist
88
+ """
89
+ if not output_path.exists():
90
+ return []
91
+
92
+ console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]")
93
+
94
+ try:
95
+ with output_path.open("r", encoding="utf-8") as f:
96
+ json_data = json.load(f)
97
+ return json_data
98
+
99
+ except Exception as e:
100
+ console.print(f"[bold yellow]Warning: Could not check meta data: {e}[/]")
101
+ return []
102
+
103
+
104
+ def _save_dataset_json(
105
+ reference_paths: Dict[str, str],
106
+ output_path: Path,
107
+ ) -> None:
108
+ """Save dataset json with reference video paths.
109
+ Args:
110
+ reference_paths: Dictionary mapping media paths to reference video paths
111
+ output_path: Path to save the output file
112
+ """
113
+
114
+ with output_path.open("r", encoding="utf-8") as f:
115
+ json_data = json.load(f)
116
+ new_json_data = json_data.copy()
117
+ for i, item in enumerate(json_data):
118
+ media_path = item["media_path"]
119
+ reference_path = reference_paths[media_path]
120
+ new_json_data[i]["reference_path"] = reference_path
121
+
122
+ with output_path.open("w", encoding="utf-8") as f:
123
+ json.dump(new_json_data, f, indent=2, ensure_ascii=False)
124
+
125
+ console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]")
126
+ console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:")
127
+ console.print(" reference_column='[cyan]reference_path[/]'")
128
+ console.print(" video_column='[cyan]media_path[/]'")
129
+
130
+
131
+ def process_media(
132
+ input_path: Path,
133
+ output_path: Path,
134
+ override: bool,
135
+ batch_size: int = 100,
136
+ ) -> None:
137
+ """Process videos and images to compute condition on videos.
138
+ Args:
139
+ input_path: Path to input video/image file or directory
140
+ output_path: Path to output reference video file
141
+ override: Whether to override existing reference video files
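+ batch_size: Number of frames processed per batch when computing edge maps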
142
+ """
143
+ if not output_path.exists():
144
+ raise FileNotFoundError(
145
+ f"Output file does not exist: {output_path}. This is also the input file for the dataset."
146
+ )
147
+
148
+ # Check for existing reference video files
149
+ meta_data = _get_meta_data(output_path)
150
+
151
+ base_dir = input_path.resolve()
152
+ console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths")
153
+
154
+ # Filter media files
155
+ media_to_process = []
156
+ skipped_media = []
157
+
158
+ def media_path_to_reference_path(media_file: Path) -> Path:
159
+ return media_file.parent / (media_file.stem + "_reference" + media_file.suffix)
160
+
161
+ media_files = [base_dir / Path(sample["media_path"]) for sample in meta_data]
162
+ for media_file in media_files:
163
+ reference_path = media_path_to_reference_path(media_file)
164
+ media_to_process.append(media_file)
165
+
166
+ console.print(f"Processing [bold]{len(media_to_process)}[/] media.")
167
+
168
+ # Initialize progress tracking
169
+ progress = Progress(
170
+ SpinnerColumn(),
171
+ TextColumn("{task.description}"),
172
+ BarColumn(bar_width=40),
173
+ MofNCompleteColumn(),
174
+ TimeElapsedColumn(),
175
+ TextColumn("•"),
176
+ TimeRemainingColumn(),
177
+ console=console,
178
+ )
179
+
180
+ # Process media files
181
+ media_paths = [item["media_path"] for item in meta_data]
182
+ reference_paths = {rel_path: str(media_path_to_reference_path(Path(rel_path))) for rel_path in media_paths}
183
+
184
+ with progress:
185
+ task = progress.add_task("Computing condition on videos", total=len(media_to_process))
186
+
187
+ for media_file in media_to_process:
188
+ progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]")
189
+
190
+ rel_path = str(media_file.resolve().relative_to(base_dir))
191
+ reference_path = media_path_to_reference_path(media_file)
192
+ reference_paths[rel_path] = str(reference_path.relative_to(base_dir))
193
+
194
+ if not reference_path.resolve().exists() or override:
195
+ try:
196
+ video, fps = read_video(media_file)
197
+
198
+ # Process frames in batches
199
+ condition_frames = []
200
+
201
+ for i in range(0, len(video), batch_size):
202
+ batch = video[i : i + batch_size]
203
+ condition_batch = compute_reference(batch)
204
+ condition_frames.append(condition_batch)
205
+
206
+ # Concatenate all edge frames
207
+ all_condition = torch.cat(condition_frames, dim=0)
208
+
209
+ # Save the edge video
210
+ save_video(all_condition, reference_path.resolve(), fps=fps)
211
+
212
+ except Exception as e:
213
+ console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]")
214
+ reference_paths.pop(rel_path)
215
+ else:
216
+ skipped_media.append(media_file)
217
+
218
+ progress.advance(task)
219
+
220
+ # Save results
221
+ _save_dataset_json(reference_paths, output_path)
222
+
223
+ # Print summary
224
+ total_to_process = len(media_files) - len(skipped_media)
225
+ console.print(
226
+ f"[bold green]✓[/] Processed [bold]{total_to_process}/{len(media_files)}[/] media successfully.",
227
+ )
228
+
229
+
230
+ app = typer.Typer(
231
+ pretty_exceptions_enable=False,
232
+ no_args_is_help=True,
233
+ help="Compute reference videos for IC-LoRA training.",
234
+ )
235
+
236
+
237
+ @app.command()
238
+ def main(
239
+ input_path: Path = typer.Argument( # noqa: B008
240
+ ...,
241
+ help="Path to input video/image file or directory containing media files",
242
+ exists=True,
243
+ ),
244
+ output: Path = typer.Option(  # noqa: B008
245
+ ...,
246
+ "--output",
247
+ "-o",
248
+ help="Path to json output file for reference video paths. "
249
+ "This is also the input file for the dataset, the output of compute_captions.py.",
250
+ ),
251
+ override: bool = typer.Option(
252
+ False,
253
+ "--override",
254
+ help="Whether to override existing reference video files",
255
+ ),
256
+ batch_size: int = typer.Option(
257
+ 100,
258
+ "--batch-size",
259
+ help="Batch size for processing videos",
260
+ ),
261
+ ) -> None:
262
+ """Compute reference videos for IC-LoRA training.
263
+ This script generates reference videos (e.g., Canny edge maps) for given videos.
264
+ The paths in the output file will be relative to the output file's directory.
265
+ Examples:
266
+ # Process all videos in a directory
267
+ compute_reference.py videos_dir/ -o videos_dir/captions.json
268
+ """
269
+
270
+ # Ensure output path is absolute
271
+ output = Path(output).resolve()
272
+ console.print(f"Output will be saved to [bold blue]{output}[/]")
273
+
274
+ # Verify output path exists
275
+ if not output.exists():
276
+ raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.")
277
+
278
+ # Process media files
279
+ process_media(
280
+ input_path=input_path,
281
+ output_path=output,
282
+ override=override,
283
+ batch_size=batch_size,
284
+ )
285
+
286
+
287
+ if __name__ == "__main__":
288
+ app()
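A minimal sketch of how this script extends each dataset entry (the `_reference` naming follows `media_path_to_reference_path` above; file names are illustrative):

    # Entry as written by caption_videos.py:
    entry = {"caption": "A dog runs across a grassy field.", "media_path": "clips/dog.mp4"}

    # compute_reference.py saves a Canny-edge video next to the source clip and records its
    # relative path on the same entry before rewriting the dataset JSON in place:
    entry["reference_path"] = "clips/dog_reference.mp4"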
packages/ltx-trainer/scripts/decode_latents.py ADDED
@@ -0,0 +1,369 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Decode precomputed video latents back into videos using the VAE.
5
+ This script loads latent files saved during preprocessing and decodes them
6
+ back into video clips using the same VAE model.
7
+ Basic usage:
8
+ python scripts/decode_latents.py /path/to/latents/dir /path/to/output \
9
+ --model-path /path/to/ltx2.safetensors
10
+ """
11
+
12
+ from pathlib import Path
13
+
14
+ import torch
15
+ import torchaudio
16
+ import torchvision.utils
17
+ import typer
18
+ from einops import rearrange
19
+ from rich.console import Console
20
+ from rich.progress import (
21
+ BarColumn,
22
+ MofNCompleteColumn,
23
+ Progress,
24
+ SpinnerColumn,
25
+ TextColumn,
26
+ TimeElapsedColumn,
27
+ TimeRemainingColumn,
28
+ )
29
+ from transformers.utils.logging import disable_progress_bar
30
+
31
+ from ltx_core.model.video_vae import SpatialTilingConfig, TemporalTilingConfig, TilingConfig
32
+ from ltx_trainer import logger
33
+ from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder
34
+ from ltx_trainer.video_utils import save_video
35
+
36
+ DEFAULT_TILE_SIZE_PIXELS = 512 # Spatial tile size in pixels (must be ≥64 and divisible by 32)
37
+ DEFAULT_TILE_OVERLAP_PIXELS = 128 # Spatial tile overlap in pixels (must be divisible by 32)
38
+ DEFAULT_TILE_SIZE_FRAMES = 128 # Temporal tile size in frames (must be ≥16 and divisible by 8)
39
+ DEFAULT_TILE_OVERLAP_FRAMES = 24 # Temporal tile overlap in frames (must be divisible by 8)
40
+
41
+ disable_progress_bar()
42
+ console = Console()
43
+ app = typer.Typer(
44
+ pretty_exceptions_enable=False,
45
+ no_args_is_help=True,
46
+ help="Decode precomputed video latents back into videos using the VAE.",
47
+ )
48
+
49
+
50
+ class LatentsDecoder:
51
+ def __init__(
52
+ self,
53
+ model_path: str,
54
+ device: str = "cuda",
55
+ vae_tiling: bool = False,
56
+ with_audio: bool = False,
57
+ ):
58
+ """Initialize the decoder with model configuration.
59
+ Args:
60
+ model_path: Path to LTX-2 checkpoint (.safetensors)
61
+ device: Device to use for computation
62
+ vae_tiling: Whether to enable VAE tiling for larger video resolutions
63
+ with_audio: Whether to load audio VAE for audio decoding
64
+ """
65
+ self.device = torch.device(device)
66
+ self.model_path = model_path
67
+ self.vae = None
68
+ self.audio_vae = None
69
+ self.vocoder = None
70
+ self.vae_tiling = vae_tiling
71
+
72
+ self._load_model(model_path, with_audio)
73
+
74
+ def _load_model(self, model_path: str, with_audio: bool = False) -> None:
75
+ """Initialize and load the VAE model(s)."""
76
+ with console.status(f"[bold]Loading video VAE decoder from {model_path}...", spinner="dots"):
77
+ self.vae = load_video_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)
78
+
79
+ if with_audio:
80
+ with console.status(f"[bold]Loading audio VAE decoder from {model_path}...", spinner="dots"):
81
+ self.audio_vae = load_audio_vae_decoder(model_path, device=self.device, dtype=torch.bfloat16)
82
+
83
+ with console.status(f"[bold]Loading vocoder from {model_path}...", spinner="dots"):
84
+ self.vocoder = load_vocoder(model_path, device=self.device)
85
+
86
+ @torch.inference_mode()
87
+ def decode(self, latents_dir: Path, output_dir: Path, seed: int | None = None) -> None:
88
+ """Decode all latent files in the directory recursively.
89
+ Args:
90
+ latents_dir: Directory containing latent files (.pt)
91
+ output_dir: Directory to save decoded videos
92
+ seed: Optional random seed for noise generation
93
+ """
94
+ # Find all .pt files recursively
95
+ latent_files = list(latents_dir.rglob("*.pt"))
96
+
97
+ if not latent_files:
98
+ logger.warning(f"No .pt files found in {latents_dir}")
99
+ return
100
+
101
+ logger.info(f"Found {len(latent_files):,} latent files to decode")
102
+
103
+ # Process files with progress bar
104
+ with Progress(
105
+ SpinnerColumn(),
106
+ TextColumn("[progress.description]{task.description}"),
107
+ BarColumn(),
108
+ MofNCompleteColumn(),
109
+ TimeElapsedColumn(),
110
+ TimeRemainingColumn(),
111
+ console=console,
112
+ ) as progress:
113
+ task = progress.add_task("Decoding latents", total=len(latent_files))
114
+
115
+ for latent_file in latent_files:
116
+ # Calculate relative path to maintain directory structure
117
+ rel_path = latent_file.relative_to(latents_dir)
118
+ output_subdir = output_dir / rel_path.parent
119
+ output_subdir.mkdir(parents=True, exist_ok=True)
120
+
121
+ try:
122
+ self._process_file(latent_file, output_subdir, seed)
123
+ except Exception as e:
124
+ logger.error(f"Error processing {latent_file}: {e}")
125
+ continue
126
+
127
+ progress.advance(task)
128
+
129
+ logger.info(f"Decoding complete! Videos saved to {output_dir}")
130
+
131
+ @torch.inference_mode()
132
+ def decode_audio(self, latents_dir: Path, output_dir: Path) -> None:
133
+ """Decode all audio latent files in the directory recursively.
134
+ Args:
135
+ latents_dir: Directory containing audio latent files (.pt)
136
+ output_dir: Directory to save decoded audio files
137
+ """
138
+ # Check if audio VAE is loaded
139
+ if self.audio_vae is None or self.vocoder is None:
140
+ logger.warning("Audio VAE or vocoder not loaded. Skipping audio decoding.")
141
+ return
142
+
143
+ # Find all .pt files recursively
144
+ latent_files = list(latents_dir.rglob("*.pt"))
145
+
146
+ if not latent_files:
147
+ logger.warning(f"No .pt files found in {latents_dir}")
148
+ return
149
+
150
+ logger.info(f"Found {len(latent_files):,} audio latent files to decode")
151
+
152
+ # Process files with progress bar
153
+ with Progress(
154
+ SpinnerColumn(),
155
+ TextColumn("[progress.description]{task.description}"),
156
+ BarColumn(),
157
+ MofNCompleteColumn(),
158
+ TimeElapsedColumn(),
159
+ TimeRemainingColumn(),
160
+ console=console,
161
+ ) as progress:
162
+ task = progress.add_task("Decoding audio latents", total=len(latent_files))
163
+
164
+ for latent_file in latent_files:
165
+ # Calculate relative path to maintain directory structure
166
+ rel_path = latent_file.relative_to(latents_dir)
167
+ output_subdir = output_dir / rel_path.parent
168
+ output_subdir.mkdir(parents=True, exist_ok=True)
169
+
170
+ try:
171
+ self._process_audio_file(latent_file, output_subdir)
172
+ except Exception as e:
173
+ logger.error(f"Error processing audio {latent_file}: {e}")
174
+ continue
175
+
176
+ progress.advance(task)
177
+
178
+ logger.info(f"Audio decoding complete! Audio files saved to {output_dir}")
179
+
180
+ def _process_file(self, latent_file: Path, output_dir: Path, seed: int | None) -> None:
181
+ """Process a single latent file."""
182
+ # Load the latent data
183
+ data = torch.load(latent_file, map_location=self.device, weights_only=False)
184
+
185
+ # Get latents - handle both old patchified [seq_len, C] and new [C, F, H, W] formats
186
+ latents = data["latents"]
187
+ num_frames = data["num_frames"]
188
+ height = data["height"]
189
+ width = data["width"]
190
+
191
+ # Check if latents need reshaping (old patchified format)
192
+ if latents.dim() == 2:
193
+ # Old format: [seq_len, C] -> reshape to [C, F, H, W]
194
+ latents = rearrange(latents, "(f h w) c -> c f h w", f=num_frames, h=height, w=width)
195
+
196
+ # Add batch dimension: [C, F, H, W] -> [1, C, F, H, W]
197
+ latents = latents.unsqueeze(0).to(device=self.device, dtype=torch.bfloat16)
198
+
199
+ # Create generator only if seed is provided
200
+ generator = None
201
+ if seed is not None:
202
+ generator = torch.Generator(device=self.device)
203
+ generator.manual_seed(seed)
204
+
205
+ # Decode the video
206
+ video = self._decode_video(latents, generator)
207
+
208
+ # Determine output format and save
209
+ is_image = video.shape[0] == 1
210
+ if is_image:
211
+ # Save as PNG for single frame
212
+ output_path = output_dir / f"{latent_file.stem}.png"
213
+ torchvision.utils.save_image(
214
+ video[0], # [C, H, W] in [0, 1]
215
+ str(output_path),
216
+ )
217
+ else:
218
+ # Save as MP4 for video using PyAV-based save_video
219
+ output_path = output_dir / f"{latent_file.stem}.mp4"
220
+ fps = data.get("fps", 24) # Use stored FPS or default to 24
221
+ save_video(
222
+ video_tensor=video, # [F, C, H, W] in [0, 1]
223
+ output_path=output_path,
224
+ fps=fps,
225
+ )
226
+
227
+ def _decode_video(self, latents: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
228
+ """Decode latents to video frames."""
229
+ if self.vae_tiling:
230
+ # Use tiled decoding for reduced VRAM
231
+ tiling_config = TilingConfig(
232
+ spatial_config=SpatialTilingConfig(
233
+ tile_size_in_pixels=DEFAULT_TILE_SIZE_PIXELS,
234
+ tile_overlap_in_pixels=DEFAULT_TILE_OVERLAP_PIXELS,
235
+ ),
236
+ temporal_config=TemporalTilingConfig(
237
+ tile_size_in_frames=DEFAULT_TILE_SIZE_FRAMES,
238
+ tile_overlap_in_frames=DEFAULT_TILE_OVERLAP_FRAMES,
239
+ ),
240
+ )
241
+ chunks = list(
242
+ self.vae.tiled_decode(
243
+ latents,
244
+ tiling_config=tiling_config,
245
+ generator=generator,
246
+ )
247
+ )
248
+ # Concatenate along temporal dimension
249
+ video = torch.cat(chunks, dim=2) # [B, C, F, H, W]
250
+ else:
251
+ # Standard full decoding
252
+ video = self.vae(latents, generator=generator) # [B, C, F, H, W]
253
+
254
+ # Convert to [F, C, H, W] format and normalize to [0, 1]
255
+ video = rearrange(video, "1 c f h w -> f c h w")
256
+ video = (video + 1) / 2 # Denormalize from [-1, 1] to [0, 1]
257
+ video = video.clamp(0, 1)
258
+
259
+ return video
260
+
261
+ def _process_audio_file(self, latent_file: Path, output_dir: Path) -> None:
262
+ """Process a single audio latent file."""
263
+ # Load the latent data
264
+ data = torch.load(latent_file, map_location=self.device, weights_only=False)
265
+
266
+ latents = data["latents"].to(device=self.device, dtype=torch.float32)
267
+ num_time_steps = data["num_time_steps"]
268
+ freq_bins = data["frequency_bins"]
269
+
270
+ # Handle both old patchified [seq_len, C] and new [C, T, F] formats
271
+ if latents.dim() == 2:
272
+ # Old format: [seq_len, channels] where seq_len = time * freq
273
+ # Reshape to [C, T, F]
274
+ latents = rearrange(latents, "(t f) c -> c t f", t=num_time_steps, f=freq_bins)
275
+
276
+ # Add batch dimension: [C, T, F] -> [1, C, T, F]
277
+ latents = latents.unsqueeze(0)
278
+
279
+ # Set correct dtype for audio VAE
280
+ latents = latents.to(dtype=torch.bfloat16)
281
+
282
+ # Decode audio using audio VAE decoder (produces mel spectrogram)
283
+ mel_spectrogram = self.audio_vae(latents)
284
+
285
+ # Convert mel spectrogram to waveform using vocoder
286
+ waveform = self.vocoder(mel_spectrogram)
287
+
288
+ # Save as WAV
289
+ output_path = output_dir / f"{latent_file.stem}.wav"
290
+ sample_rate = self.vocoder.output_sampling_rate
291
+ torchaudio.save(str(output_path), waveform[0].cpu(), sample_rate)
292
+
293
+
294
+ @app.command()
295
+ def main(
296
+ latents_dir: str = typer.Argument(
297
+ ...,
298
+ help="Directory containing the precomputed latent files (searched recursively)",
299
+ ),
300
+ output_dir: str = typer.Argument(
301
+ ...,
302
+ help="Directory to save the decoded videos (maintains same folder hierarchy as input)",
303
+ ),
304
+ model_path: str = typer.Option(
305
+ ...,
306
+ help="Path to LTX-2 checkpoint (.safetensors file)",
307
+ ),
308
+ device: str = typer.Option(
309
+ default="cuda",
310
+ help="Device to use for computation",
311
+ ),
312
+ vae_tiling: bool = typer.Option(
313
+ default=False,
314
+ help="Enable VAE tiling for larger video resolutions",
315
+ ),
316
+ seed: int | None = typer.Option(
317
+ default=None,
318
+ help="Random seed for noise generation during decoding",
319
+ ),
320
+ with_audio: bool = typer.Option(
321
+ default=False,
322
+ help="Also decode audio latents (requires audio_latents directory)",
323
+ ),
324
+ audio_latents_dir: str | None = typer.Option(
325
+ default=None,
326
+ help="Directory containing audio latent files (defaults to 'audio_latents' sibling of latents_dir)",
327
+ ),
328
+ ) -> None:
329
+ """Decode precomputed video latents back into videos using the VAE.
330
+ This script recursively searches for .pt latent files in the input directory
331
+ and decodes them to videos, maintaining the same folder hierarchy in the output.
332
+ Examples:
333
+ # Basic usage
334
+ python scripts/decode_latents.py /path/to/latents /path/to/videos \\
335
+ --model-path /path/to/ltx2.safetensors
336
+ # With VAE tiling for large videos
337
+ python scripts/decode_latents.py /path/to/latents /path/to/videos \\
338
+ --model-path /path/to/ltx2.safetensors --vae-tiling
339
+ # With audio decoding
340
+ python scripts/decode_latents.py /path/to/latents /path/to/videos \\
341
+ --model-path /path/to/ltx2.safetensors --with-audio
342
+ """
343
+ latents_path = Path(latents_dir)
344
+ output_path = Path(output_dir)
345
+
346
+ if not latents_path.exists() or not latents_path.is_dir():
347
+ raise typer.BadParameter(f"Latents directory does not exist: {latents_path}")
348
+
349
+ decoder = LatentsDecoder(
350
+ model_path=model_path,
351
+ device=device,
352
+ vae_tiling=vae_tiling,
353
+ with_audio=with_audio,
354
+ )
355
+ decoder.decode(latents_path, output_path, seed=seed)
356
+
357
+ # Decode audio if requested
358
+ if with_audio:
359
+ audio_path = Path(audio_latents_dir) if audio_latents_dir else latents_path.parent / "audio_latents"
360
+
361
+ if audio_path.exists():
362
+ audio_output_path = output_path.parent / "decoded_audio"
363
+ decoder.decode_audio(audio_path, audio_output_path)
364
+ else:
365
+ logger.warning(f"Audio latents directory not found: {audio_path}")
366
+
367
+
368
+ if __name__ == "__main__":
369
+ app()
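A minimal sketch of the .pt payloads this decoder expects (keys follow `_process_file` and `_process_audio_file` above; shapes, channel counts, and file names are illustrative and depend on the preprocessing step that produced the latents):

    import torch

    # Video latent file: new [C, F, H, W] layout (the old patchified [seq_len, C] layout is also handled)
    video_payload = {
        "latents": torch.randn(128, 8, 16, 16),  # [C, F, H, W]
        "num_frames": 8,
        "height": 16,
        "width": 16,
        "fps": 24,  # optional; the decoder falls back to 24 when missing
    }
    torch.save(video_payload, "clip_0001.pt")

    # Audio latent file: [C, T, F] layout (old [seq_len, C] is also handled)
    audio_payload = {
        "latents": torch.randn(8, 32, 16),  # [C, T, F]
        "num_time_steps": 32,
        "frequency_bins": 16,
    }
    torch.save(audio_payload, "clip_0001_audio.pt")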
packages/ltx-trainer/scripts/process_captions.py ADDED
@@ -0,0 +1,435 @@
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+ Compute text embeddings for video generation training.
5
+ This module provides functionality for processing text captions, including:
6
+ - Loading captions from various file formats (CSV, JSON, JSONL)
7
+ - Cleaning and preprocessing text (removing LLM prefixes, adding ID tokens)
8
+ - CaptionsDataset for caption-only preprocessing workflows
9
+ Can be used as a standalone script:
10
+ python scripts/process_captions.py dataset.json --output-dir /path/to/output \
11
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma
12
+ """
13
+
14
+ import json
15
+ import os
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ import pandas as pd
20
+ import torch
21
+ import typer
22
+ from rich.console import Console
23
+ from rich.progress import (
24
+ BarColumn,
25
+ MofNCompleteColumn,
26
+ Progress,
27
+ SpinnerColumn,
28
+ TaskProgressColumn,
29
+ TextColumn,
30
+ TimeElapsedColumn,
31
+ TimeRemainingColumn,
32
+ )
33
+ from torch.utils.data import DataLoader, Dataset
34
+ from transformers.utils.logging import disable_progress_bar
35
+
36
+ from ltx_trainer import logger
37
+ from ltx_trainer.model_loader import load_embeddings_processor, load_text_encoder
38
+
39
+ # Disable tokenizers parallelism to avoid warnings
40
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
+
42
+ disable_progress_bar()
43
+
44
+ # Common phrases that LLMs often add to captions that we might want to remove
45
+ COMMON_BEGINNING_PHRASES: tuple[str, ...] = (
46
+ "This video",
47
+ "The video",
48
+ "This clip",
49
+ "The clip",
50
+ "The animation",
51
+ "This image",
52
+ "The image",
53
+ "This picture",
54
+ "The picture",
55
+ )
56
+
57
+ COMMON_CONTINUATION_WORDS: tuple[str, ...] = (
58
+ "shows",
59
+ "depicts",
60
+ "features",
61
+ "captures",
62
+ "highlights",
63
+ "introduces",
64
+ "presents",
65
+ )
66
+
67
+ COMMON_LLM_START_PHRASES: tuple[str, ...] = (
68
+ "In the video,",
69
+ "In this video,",
70
+ "In this video clip,",
71
+ "In the clip,",
72
+ "Caption:",
73
+ *(
74
+ f"{beginning} {continuation}"
75
+ for beginning in COMMON_BEGINNING_PHRASES
76
+ for continuation in COMMON_CONTINUATION_WORDS
77
+ ),
78
+ )
79
+
80
+ app = typer.Typer(
81
+ pretty_exceptions_enable=False,
82
+ no_args_is_help=True,
83
+ help="Process text captions and save embeddings for video generation training.",
84
+ )
85
+
86
+
87
+ class CaptionsDataset(Dataset):
88
+ """
89
+ Dataset for processing text captions only.
90
+ This dataset is designed for caption preprocessing workflows where you only need
91
+ to process text without loading videos. Useful for:
92
+ - Precomputing text embeddings
93
+ - Caption cleaning and preprocessing
94
+ - Text-only preprocessing pipelines
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ dataset_file: str | Path,
100
+ caption_column: str,
101
+ media_column: str = "media_path",
102
+ lora_trigger: str | None = None,
103
+ remove_llm_prefixes: bool = False,
104
+ ) -> None:
105
+ """
106
+ Initialize the captions dataset.
107
+ Args:
108
+ dataset_file: Path to CSV/JSON/JSONL metadata file
109
+ caption_column: Column name for captions in the metadata file
110
+ media_column: Column name for media paths (used for output naming)
111
+ lora_trigger: Optional trigger word to prepend to each caption
112
+ remove_llm_prefixes: Whether to remove common LLM-generated prefixes
113
+ """
114
+ super().__init__()
115
+
116
+ self.dataset_file = Path(dataset_file)
117
+ self.caption_column = caption_column
118
+ self.media_column = media_column
119
+ self.lora_trigger = f"{lora_trigger.strip()} " if lora_trigger else ""
120
+
121
+ # Load captions with their corresponding output embedding paths
122
+ self.caption_data = self._load_caption_data()
123
+
124
+ # Convert to lists for indexing
125
+ self.output_paths = list(self.caption_data.keys())
126
+ self.prompts = list(self.caption_data.values())
127
+
128
+ # Clean LLM start phrases if requested
129
+ if remove_llm_prefixes:
130
+ self._clean_llm_prefixes()
131
+
132
+ def __len__(self) -> int:
133
+ return len(self.prompts)
134
+
135
+ def __getitem__(self, index: int) -> dict[str, Any]:
136
+ """Get a single caption with optional trigger word prepended and output path."""
137
+ prompt = self.lora_trigger + self.prompts[index]
138
+ return {
139
+ "prompt": prompt,
140
+ "output_path": self.output_paths[index],
141
+ "index": index,
142
+ }
143
+
144
+ def _load_caption_data(self) -> dict[str, str]:
145
+ """Load captions and compute their output embedding paths."""
146
+ if self.dataset_file.suffix == ".csv":
147
+ return self._load_caption_data_from_csv()
148
+ elif self.dataset_file.suffix == ".json":
149
+ return self._load_caption_data_from_json()
150
+ elif self.dataset_file.suffix == ".jsonl":
151
+ return self._load_caption_data_from_jsonl()
152
+ else:
153
+ raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.")
154
+
155
+ def _load_caption_data_from_csv(self) -> dict[str, str]:
156
+ """Load captions from a CSV file and compute output embedding paths."""
157
+ df = pd.read_csv(self.dataset_file)
158
+
159
+ if self.caption_column not in df.columns:
160
+ raise ValueError(f"Column '{self.caption_column}' not found in CSV file")
161
+ if self.media_column not in df.columns:
162
+ raise ValueError(f"Column '{self.media_column}' not found in CSV file")
163
+
164
+ caption_data = {}
165
+ for _, row in df.iterrows():
166
+ media_path = Path(row[self.media_column].strip())
167
+ # Convert media path to embedding output path (same structure, .pt extension)
168
+ output_path = str(media_path.with_suffix(".pt"))
169
+ caption_data[output_path] = row[self.caption_column]
170
+
171
+ return caption_data
172
+
173
+ def _load_caption_data_from_json(self) -> dict[str, str]:
174
+ """Load captions from a JSON file and compute output embedding paths."""
175
+ with open(self.dataset_file, "r", encoding="utf-8") as file:
176
+ data = json.load(file)
177
+
178
+ if not isinstance(data, list):
179
+ raise ValueError("JSON file must contain a list of objects")
180
+
181
+ caption_data = {}
182
+ for entry in data:
183
+ if self.caption_column not in entry:
184
+ raise ValueError(f"Key '{self.caption_column}' not found in JSON entry: {entry}")
185
+ if self.media_column not in entry:
186
+ raise ValueError(f"Key '{self.media_column}' not found in JSON entry: {entry}")
187
+
188
+ media_path = Path(entry[self.media_column].strip())
189
+ # Convert media path to embedding output path (same structure, .pt extension)
190
+ output_path = str(media_path.with_suffix(".pt"))
191
+ caption_data[output_path] = entry[self.caption_column]
192
+
193
+ return caption_data
194
+
195
+ def _load_caption_data_from_jsonl(self) -> dict[str, str]:
196
+ """Load captions from a JSONL file and compute output embedding paths."""
197
+ caption_data = {}
198
+ with open(self.dataset_file, "r", encoding="utf-8") as file:
199
+ for line in file:
200
+ entry = json.loads(line)
201
+ if self.caption_column not in entry:
202
+ raise ValueError(f"Key '{self.caption_column}' not found in JSONL entry: {entry}")
203
+ if self.media_column not in entry:
204
+ raise ValueError(f"Key '{self.media_column}' not found in JSONL entry: {entry}")
205
+
206
+ media_path = Path(entry[self.media_column].strip())
207
+ # Convert media path to embedding output path (same structure, .pt extension)
208
+ output_path = str(media_path.with_suffix(".pt"))
209
+ caption_data[output_path] = entry[self.caption_column]
210
+
211
+ return caption_data
212
+
213
+ def _clean_llm_prefixes(self) -> None:
214
+ """Remove common LLM-generated prefixes from captions."""
215
+ for i in range(len(self.prompts)):
216
+ self.prompts[i] = self.prompts[i].strip()
217
+ for phrase in COMMON_LLM_START_PHRASES:
218
+ if self.prompts[i].startswith(phrase):
219
+ self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()
220
+ break
221
+
222
+
223
+ def compute_captions_embeddings( # noqa: PLR0913
224
+ dataset_file: str | Path,
225
+ output_dir: str,
226
+ model_path: str,
227
+ text_encoder_path: str,
228
+ caption_column: str = "caption",
229
+ media_column: str = "media_path",
230
+ lora_trigger: str | None = None,
231
+ remove_llm_prefixes: bool = False,
232
+ batch_size: int = 8,
233
+ device: str = "cuda",
234
+ load_in_8bit: bool = False,
235
+ ) -> None:
236
+ """
237
+ Process captions and save text embeddings.
238
+ Args:
239
+ dataset_file: Path to metadata file (CSV/JSON/JSONL) containing captions and media paths
240
+ output_dir: Directory to save embeddings
241
+ model_path: Path to LTX-2 checkpoint (.safetensors)
242
+ text_encoder_path: Path to Gemma text encoder directory
243
+ caption_column: Column name containing captions in the metadata file
244
+ media_column: Column name containing media paths (used for output naming)
245
+ lora_trigger: Optional trigger word to prepend to each caption
246
+ remove_llm_prefixes: Whether to remove common LLM-generated prefixes
247
+ batch_size: Batch size for processing
248
+ device: Device to use for computation
249
+ load_in_8bit: Whether to load the Gemma text encoder in 8-bit precision
250
+ """
251
+
252
+ console = Console()
253
+
254
+ # Create dataset
255
+ dataset = CaptionsDataset(
256
+ dataset_file=dataset_file,
257
+ caption_column=caption_column,
258
+ media_column=media_column,
259
+ lora_trigger=lora_trigger,
260
+ remove_llm_prefixes=remove_llm_prefixes,
261
+ )
262
+ logger.info(f"Loaded {len(dataset):,} captions")
263
+
264
+ output_path = Path(output_dir)
265
+ output_path.mkdir(parents=True, exist_ok=True)
266
+
267
+ # Load text encoder and embeddings processor
268
+ with console.status("[bold]Loading Gemma text encoder...", spinner="dots"):
269
+ text_encoder = load_text_encoder(
270
+ text_encoder_path,
271
+ device=device,
272
+ dtype=torch.bfloat16,
273
+ load_in_8bit=load_in_8bit,
274
+ )
275
+ embeddings_processor = load_embeddings_processor(
276
+ model_path,
277
+ device=device,
278
+ dtype=torch.bfloat16,
279
+ )
280
+
281
+ logger.info("Text encoder and embeddings processor loaded successfully")
282
+
283
+ # TODO(batch-tokenization): The current Gemma tokenizer doesn't support batched tokenization.
284
+ if batch_size > 1:
285
+ logger.warning(
286
+ "Batch size greater than 1 is not currently supported with the Gemma tokenizer. "
287
+ "Overriding batch_size to 1. This will be fixed in a future update."
288
+ )
289
+ batch_size = 1
290
+
291
+ # Create dataloader
292
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
293
+
294
+ # Process batches
295
+ total_batches = len(dataloader)
296
+ logger.info(f"Processing captions in {total_batches:,} batches...")
297
+
298
+ with Progress(
299
+ SpinnerColumn(),
300
+ TextColumn("[progress.description]{task.description}"),
301
+ BarColumn(),
302
+ TaskProgressColumn(),
303
+ MofNCompleteColumn(),
304
+ TimeElapsedColumn(),
305
+ TimeRemainingColumn(),
306
+ console=console,
307
+ ) as progress:
308
+ task = progress.add_task("Processing captions", total=len(dataloader))
309
+ for batch in dataloader:
310
+ # Encode prompts using text_encoder.encode() + feature_extractor
311
+ # (returns video/audio features before connector).
312
+ # The connector is applied during training via embeddings_processor
313
+ with torch.inference_mode():
314
+ # TODO(batch-tokenization): When tokenizer supports batching, encode all prompts at once.
315
+ # For now, process one at a time:
316
+ for i in range(len(batch["prompt"])):
317
+ hidden_states, prompt_attention_mask = text_encoder.encode(batch["prompt"][i], padding_side="left")
318
+ video_prompt_embeds, audio_prompt_embeds = embeddings_processor.feature_extractor(
319
+ hidden_states, prompt_attention_mask, "left"
320
+ )
321
+
322
+ output_rel_path = Path(batch["output_path"][i])
323
+
324
+ # Create output directory maintaining structure
325
+ output_dir_path = output_path / output_rel_path.parent
326
+ output_dir_path.mkdir(parents=True, exist_ok=True)
327
+
328
+ embedding_data = {
329
+ "video_prompt_embeds": video_prompt_embeds[0].cpu().contiguous(),
330
+ "prompt_attention_mask": prompt_attention_mask[0].cpu().contiguous(),
331
+ }
332
+ if audio_prompt_embeds is not None:
333
+ embedding_data["audio_prompt_embeds"] = audio_prompt_embeds[0].cpu().contiguous()
334
+
335
+ output_file = output_path / output_rel_path
336
+ torch.save(embedding_data, output_file)
337
+
338
+ progress.advance(task)
339
+
340
+ logger.info(f"Processed {len(dataset):,} captions. Embeddings saved to {output_path}")
341
+
342
+
343
+ @app.command()
344
+ def main( # noqa: PLR0913
345
+ dataset_file: str = typer.Argument(
346
+ ...,
347
+ help="Path to metadata file (CSV/JSON/JSONL) containing captions and media paths",
348
+ ),
349
+ output_dir: str = typer.Option(
350
+ ...,
351
+ help="Output directory to save text embeddings",
352
+ ),
353
+ model_path: str = typer.Option(
354
+ ...,
355
+ help="Path to LTX-2 checkpoint (.safetensors file)",
356
+ ),
357
+ text_encoder_path: str = typer.Option(
358
+ ...,
359
+ help="Path to Gemma text encoder directory",
360
+ ),
361
+ caption_column: str = typer.Option(
362
+ default="caption",
363
+ help="Column name containing captions in the dataset JSON/JSONL/CSV file",
364
+ ),
365
+ media_column: str = typer.Option(
366
+ default="media_path",
367
+ help="Column name in the dataset JSON/JSONL/CSV file containing media paths "
368
+ "(used for output file naming and folder structure)",
369
+ ),
370
+ batch_size: int = typer.Option(
371
+ default=8,
372
+ help="Batch size for processing",
373
+ ),
374
+ device: str = typer.Option(
375
+ default="cuda",
376
+ help="Device to use for computation",
377
+ ),
378
+ lora_trigger: str | None = typer.Option(
379
+ default=None,
380
+ help="Optional trigger word to prepend to each caption (activates the LoRA during inference)",
381
+ ),
382
+ remove_llm_prefixes: bool = typer.Option(
383
+ default=False,
384
+ help="Remove common LLM-generated prefixes from captions",
385
+ ),
386
+ load_text_encoder_in_8bit: bool = typer.Option(
387
+ default=False,
388
+ help="Load the Gemma text encoder in 8-bit precision to save GPU memory (requires bitsandbytes)",
389
+ ),
390
+ ) -> None:
391
+ """Process text captions and save embeddings for video generation training.
392
+ This script processes captions from metadata files and saves text embeddings
393
+ that can be used for training video generation models. The output embeddings
394
+ will maintain the same folder structure and naming as the corresponding media files.
395
+ Note: This script is designed for LTX-2 models which use the Gemma text encoder.
396
+ Examples:
397
+ # Process captions with LTX-2 model
398
+ python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
399
+ --model-path /path/to/ltx2_checkpoint.safetensors \\
400
+ --text-encoder-path /path/to/gemma
401
+ # Add a trigger word for LoRA training
402
+ python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
403
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
404
+ --lora-trigger "mytoken"
405
+ # Remove LLM-generated prefixes from captions
406
+ python scripts/process_captions.py dataset.json --output-dir ./embeddings \\
407
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
408
+ --remove-llm-prefixes
409
+ """
410
+
411
+ # Validate dataset file
412
+ if not Path(dataset_file).is_file():
413
+ raise typer.BadParameter(f"Dataset file not found: {dataset_file}")
414
+
415
+ if lora_trigger:
416
+ logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions')
417
+
418
+ # Process embeddings
419
+ compute_captions_embeddings(
420
+ dataset_file=dataset_file,
421
+ output_dir=output_dir,
422
+ model_path=model_path,
423
+ text_encoder_path=text_encoder_path,
424
+ caption_column=caption_column,
425
+ media_column=media_column,
426
+ lora_trigger=lora_trigger,
427
+ remove_llm_prefixes=remove_llm_prefixes,
428
+ batch_size=batch_size,
429
+ device=device,
430
+ load_in_8bit=load_text_encoder_in_8bit,
431
+ )
432
+
433
+
434
+ if __name__ == "__main__":
435
+ app()
packages/ltx-trainer/scripts/process_dataset.py ADDED
@@ -0,0 +1,317 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Preprocess a video dataset by computing video clips latents and text captions embeddings.
5
+ This script provides a command-line interface for preprocessing video datasets by computing
6
+ latent representations of video clips and text embeddings of their captions. The preprocessed
7
+ data can be used to accelerate training of video generation models and to save GPU memory.
8
+ Basic usage:
9
+ python scripts/process_dataset.py /path/to/dataset.json --resolution-buckets 768x768x49 \
10
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma
11
+ The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.
12
+ """
13
+
14
+ from pathlib import Path
15
+
16
+ import typer
17
+ from decode_latents import LatentsDecoder
18
+ from process_captions import compute_captions_embeddings
19
+ from process_videos import compute_latents, compute_scaled_resolution_buckets, parse_resolution_buckets
20
+ from rich.console import Console
21
+
22
+ from ltx_trainer import logger
23
+ from ltx_trainer.gpu_utils import free_gpu_memory_context
24
+
25
+ console = Console()
26
+
27
+ app = typer.Typer(
28
+ pretty_exceptions_enable=False,
29
+ no_args_is_help=True,
30
+ help="Preprocess a video dataset by computing video clips latents and text captions embeddings. "
31
+ "The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.",
32
+ )
33
+
34
+
35
+ def preprocess_dataset( # noqa: PLR0913
36
+ dataset_file: str,
37
+ caption_column: str,
38
+ video_column: str,
39
+ resolution_buckets: list[tuple[int, int, int]],
40
+ batch_size: int,
41
+ output_dir: str | None,
42
+ lora_trigger: str | None,
43
+ vae_tiling: bool,
44
+ decode: bool,
45
+ model_path: str,
46
+ text_encoder_path: str,
47
+ device: str,
48
+ remove_llm_prefixes: bool = False,
49
+ reference_column: str | None = None,
50
+ reference_downscale_factor: int = 1,
51
+ with_audio: bool = False,
52
+ load_text_encoder_in_8bit: bool = False,
53
+ ) -> None:
54
+ """Run the preprocessing pipeline with the given arguments."""
55
+ # Validate dataset file
56
+ _validate_dataset_file(dataset_file)
57
+
58
+ # Set up output directories
59
+ output_base = Path(output_dir) if output_dir else Path(dataset_file).parent / ".precomputed"
60
+ conditions_dir = output_base / "conditions"
61
+ latents_dir = output_base / "latents"
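+ # Illustrative layout of the preprocessed outputs under output_base (directory names are taken from
+ # the code in this script; some directories only appear depending on the flags used):
+ #   <output_base>/
+ #     conditions/           text embeddings for the captions
+ #     latents/              video latents
+ #     audio_latents/        only with --with-audio
+ #     reference_latents/    only with --reference-column
+ #     decoded_videos/       only with --decode (plus decoded_reference_videos/ and decoded_audio/)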
62
+
63
+ if lora_trigger:
64
+ logger.info(f'LoRA trigger word "{lora_trigger}" will be prepended to all captions')
65
+
66
+ with free_gpu_memory_context():
67
+ # Process captions using the dedicated function
68
+ compute_captions_embeddings(
69
+ dataset_file=dataset_file,
70
+ output_dir=str(conditions_dir),
71
+ model_path=model_path,
72
+ text_encoder_path=text_encoder_path,
73
+ caption_column=caption_column,
74
+ media_column=video_column,
75
+ lora_trigger=lora_trigger,
76
+ remove_llm_prefixes=remove_llm_prefixes,
77
+ batch_size=batch_size,
78
+ device=device,
79
+ load_in_8bit=load_text_encoder_in_8bit,
80
+ )
81
+
82
+ # Process videos using the dedicated function
83
+ audio_latents_dir = None
84
+ if with_audio:
85
+ logger.info("Audio preprocessing enabled - will extract and encode audio from videos")
86
+ audio_latents_dir = output_base / "audio_latents"
87
+
88
+ with free_gpu_memory_context():
89
+ compute_latents(
90
+ dataset_file=dataset_file,
91
+ video_column=video_column,
92
+ resolution_buckets=resolution_buckets,
93
+ output_dir=str(latents_dir),
94
+ model_path=model_path,
95
+ batch_size=batch_size,
96
+ device=device,
97
+ vae_tiling=vae_tiling,
98
+ with_audio=with_audio,
99
+ audio_output_dir=str(audio_latents_dir) if audio_latents_dir else None,
100
+ )
101
+
102
+ # Process reference videos if reference_column is provided
103
+ if reference_column:
104
+ # Validate: scaled references with multiple buckets can cause ambiguous bucket matching
105
+ if reference_downscale_factor > 1 and len(resolution_buckets) > 1:
106
+ raise ValueError(
107
+ "When using --reference-downscale-factor > 1, only a single resolution bucket is supported. "
108
+ "Using multiple buckets with scaled references can cause ambiguous bucket matching "
109
+ "(e.g., a 512x256 reference could match either the scaled-down 1024x512 bucket or the 512x256 "
110
+ "bucket). Please use a single resolution bucket or set --reference-downscale-factor to 1."
111
+ )
112
+
113
+ # Calculate and validate scaled resolution buckets for reference videos
114
+ reference_buckets = compute_scaled_resolution_buckets(resolution_buckets, reference_downscale_factor)
115
+
116
+ if reference_downscale_factor > 1:
117
+ logger.info(
118
+ f"Processing reference videos for IC-LoRA training at 1/{reference_downscale_factor} resolution..."
119
+ )
120
+ logger.info(f"Reference resolution buckets: {reference_buckets}")
121
+ else:
122
+ logger.info("Processing reference videos for IC-LoRA training...")
123
+
124
+ reference_latents_dir = output_base / "reference_latents"
125
+
126
+ compute_latents(
127
+ dataset_file=dataset_file,
128
+ main_media_column=video_column,
129
+ video_column=reference_column,
130
+ resolution_buckets=reference_buckets,
131
+ output_dir=str(reference_latents_dir),
132
+ model_path=model_path,
133
+ batch_size=batch_size,
134
+ device=device,
135
+ vae_tiling=vae_tiling,
136
+ )
137
+
138
+ # Handle decoding if requested (for verification)
139
+ if decode:
140
+ logger.info("Decoding latents for verification...")
141
+
142
+ decoder = LatentsDecoder(
143
+ model_path=model_path,
144
+ device=device,
145
+ vae_tiling=vae_tiling,
146
+ with_audio=with_audio,
147
+ )
148
+ decoder.decode(latents_dir, output_base / "decoded_videos")
149
+
150
+ # Also decode reference videos if they exist
151
+ if reference_column:
152
+ reference_latents_dir = output_base / "reference_latents"
153
+ if reference_latents_dir.exists():
154
+ logger.info("Decoding reference videos...")
155
+ decoder.decode(reference_latents_dir, output_base / "decoded_reference_videos")
156
+
157
+ # Decode audio latents if they exist
158
+ if with_audio and audio_latents_dir and audio_latents_dir.exists():
159
+ logger.info("Decoding audio latents...")
160
+ decoder.decode_audio(audio_latents_dir, output_base / "decoded_audio")
161
+
162
+ # Print summary
163
+ logger.info(f"Dataset preprocessing complete! Results saved to {output_base}")
164
+ if reference_column:
165
+ logger.info("Reference videos processed and saved to reference_latents/ directory for IC-LoRA training")
166
+ if with_audio:
167
+ logger.info("Audio latents saved to audio_latents/ directory for audio-video training")
168
+
169
+
170
+ def _validate_dataset_file(dataset_path: str) -> None:
171
+ """Validate that the dataset file exists and has the correct format."""
172
+ dataset_file = Path(dataset_path)
173
+
174
+ if not dataset_file.exists():
175
+ raise FileNotFoundError(f"Dataset file does not exist: {dataset_file}")
176
+
177
+ if not dataset_file.is_file():
178
+ raise ValueError(f"Dataset path must be a file, not a directory: {dataset_file}")
179
+
180
+ if dataset_file.suffix.lower() not in [".csv", ".json", ".jsonl"]:
181
+ raise ValueError(f"Dataset file must be CSV, JSON, or JSONL format: {dataset_file}")
182
+
183
+
184
+ @app.command()
185
+ def main( # noqa: PLR0913
186
+ dataset_path: str = typer.Argument(
187
+ ...,
188
+ help="Path to metadata file (CSV/JSON/JSONL) containing captions and video paths",
189
+ ),
190
+ resolution_buckets: str = typer.Option(
191
+ ...,
192
+ help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")',
193
+ ),
194
+ model_path: str = typer.Option(
195
+ ...,
196
+ help="Path to LTX-2 checkpoint (.safetensors file)",
197
+ ),
198
+ text_encoder_path: str = typer.Option(
199
+ ...,
200
+ help="Path to Gemma text encoder directory",
201
+ ),
202
+ caption_column: str = typer.Option(
203
+ default="caption",
204
+ help="Column name containing captions in the dataset JSON/JSONL/CSV file",
205
+ ),
206
+ video_column: str = typer.Option(
207
+ default="media_path",
208
+ help="Column name containing video paths in the dataset JSON/JSONL/CSV file",
209
+ ),
210
+ batch_size: int = typer.Option(
211
+ default=1,
212
+ help="Batch size for preprocessing",
213
+ ),
214
+ device: str = typer.Option(
215
+ default="cuda",
216
+ help="Device to use for computation",
217
+ ),
218
+ vae_tiling: bool = typer.Option(
219
+ default=False,
220
+ help="Enable VAE tiling for larger video resolutions",
221
+ ),
222
+ output_dir: str | None = typer.Option(
223
+ default=None,
224
+ help="Output directory (defaults to .precomputed in dataset directory)",
225
+ ),
226
+ lora_trigger: str | None = typer.Option(
227
+ default=None,
228
+ help="Optional trigger word to prepend to each caption (activates the LoRA during inference)",
229
+ ),
230
+ decode: bool = typer.Option(
231
+ default=False,
232
+ help="Decode and save latents after encoding (videos and audio) for verification",
233
+ ),
234
+ remove_llm_prefixes: bool = typer.Option(
235
+ default=False,
236
+ help="Remove LLM prefixes from captions",
237
+ ),
238
+ reference_column: str | None = typer.Option(
239
+ default=None,
240
+ help="Column name containing reference video paths (for video-to-video training)",
241
+ ),
242
+ with_audio: bool = typer.Option(
243
+ default=False,
244
+ help="Extract and encode audio from video files",
245
+ ),
246
+ load_text_encoder_in_8bit: bool = typer.Option(
247
+ default=False,
248
+ help="Load the Gemma text encoder in 8-bit precision to save GPU memory (requires bitsandbytes)",
249
+ ),
250
+ reference_downscale_factor: int = typer.Option(
251
+ default=1,
252
+ help="Downscale factor for reference video resolution. When > 1, reference videos are processed at "
253
+ "1/n resolution (e.g., 2 means half resolution). Used for efficient IC-LoRA training.",
254
+ ),
255
+ ) -> None:
256
+ """Preprocess a video dataset by computing and saving latents and text embeddings.
257
+ The dataset must be a CSV, JSON, or JSONL file with columns for captions and video paths.
258
+ This script is designed for LTX-2 models which use the Gemma text encoder.
259
+ Examples:
260
+ # Process a dataset with LTX-2 model
261
+ python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
262
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma
263
+ # Process dataset with custom column names
264
+ python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
265
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
266
+ --caption-column "text" --video-column "video_path"
267
+ # Process dataset with reference videos for IC-LoRA training
268
+ python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
269
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
270
+ --reference-column "reference_path"
271
+ # Process dataset with scaled reference videos (half resolution) for efficient IC-LoRA
272
+ python scripts/process_dataset.py dataset.json --resolution-buckets 768x768x25 \\
273
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
274
+ --reference-column "reference_path" --reference-downscale-factor 2
275
+ # Process dataset with audio for audio-video training
276
+ python scripts/process_dataset.py dataset.json --resolution-buckets 768x512x97 \\
277
+ --model-path /path/to/ltx2.safetensors --text-encoder-path /path/to/gemma \\
278
+ --with-audio
279
+ """
280
+ parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets)
281
+
282
+ if len(parsed_resolution_buckets) > 1:
283
+ logger.warning(
284
+ "Using multiple resolution buckets. "
285
+ "When training with multiple resolution buckets, you must use a batch size of 1."
286
+ )
287
+
288
+ # Validate reference_downscale_factor
289
+ if reference_downscale_factor < 1:
290
+ raise typer.BadParameter("--reference-downscale-factor must be >= 1")
291
+
292
+ if reference_downscale_factor > 1 and not reference_column:
293
+ logger.warning("--reference-downscale-factor specified but no --reference-column provided. Ignoring.")
294
+
295
+ preprocess_dataset(
296
+ dataset_file=dataset_path,
297
+ caption_column=caption_column,
298
+ video_column=video_column,
299
+ resolution_buckets=parsed_resolution_buckets,
300
+ batch_size=batch_size,
301
+ output_dir=output_dir,
302
+ lora_trigger=lora_trigger,
303
+ vae_tiling=vae_tiling,
304
+ decode=decode,
305
+ model_path=model_path,
306
+ text_encoder_path=text_encoder_path,
307
+ device=device,
308
+ remove_llm_prefixes=remove_llm_prefixes,
309
+ reference_column=reference_column,
310
+ reference_downscale_factor=reference_downscale_factor,
311
+ with_audio=with_audio,
312
+ load_text_encoder_in_8bit=load_text_encoder_in_8bit,
313
+ )
314
+
315
+
316
+ if __name__ == "__main__":
317
+ app()
packages/ltx-trainer/scripts/process_videos.py ADDED
@@ -0,0 +1,1039 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Compute latent representations for video generation training.
5
+ This module provides functionality for processing video and image files, including:
6
+ - Loading videos/images from various file formats (CSV, JSON, JSONL)
7
+ - Resizing, cropping, and transforming media
8
+ - MediaDataset for video-only preprocessing workflows
9
+ - Resolution-bucket selection for grouping videos by size and frame count
10
+ Can be used as a standalone script:
11
+ python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \
12
+ --output-dir /path/to/output --model-path /path/to/ltx2.safetensors
13
+ """
14
+
15
+ import json
16
+ import math
17
+ from pathlib import Path
18
+ from typing import Any
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+ import torch
23
+ import torchaudio
24
+ import typer
25
+ from pillow_heif import register_heif_opener
26
+ from rich.console import Console
27
+ from rich.progress import (
28
+ BarColumn,
29
+ MofNCompleteColumn,
30
+ Progress,
31
+ SpinnerColumn,
32
+ TaskProgressColumn,
33
+ TextColumn,
34
+ TimeElapsedColumn,
35
+ TimeRemainingColumn,
36
+ )
37
+ from torch.utils.data import DataLoader, Dataset
38
+ from torchvision import transforms
39
+ from torchvision.transforms import InterpolationMode
40
+ from torchvision.transforms.functional import crop, resize, to_tensor
41
+ from transformers.utils.logging import disable_progress_bar
42
+
43
+ from ltx_core.model.audio_vae import AudioProcessor
44
+ from ltx_core.types import Audio
45
+ from ltx_trainer import logger
46
+ from ltx_trainer.model_loader import load_audio_vae_encoder, load_video_vae_encoder
47
+ from ltx_trainer.utils import open_image_as_srgb
48
+ from ltx_trainer.video_utils import get_video_frame_count, read_video
49
+
50
+ disable_progress_bar()
51
+
52
+ # Register HEIF/HEIC support
53
+ register_heif_opener()
54
+
55
+ # Constants for validation
56
+ VAE_SPATIAL_FACTOR = 32
57
+ VAE_TEMPORAL_FACTOR = 8
58
+
59
+ # Audio constants
60
+ AUDIO_LATENT_CHANNELS = 8
61
+ AUDIO_FREQUENCY_BINS = 16
62
+
63
+ DEFAULT_TILE_SIZE = 512 # Spatial tile size in pixels (must be ≥64 and divisible by 32)
64
+ DEFAULT_TILE_OVERLAP = 128 # Spatial tile overlap in pixels (must be divisible by 32)
65
+
66
+ app = typer.Typer(
67
+ pretty_exceptions_enable=False,
68
+ no_args_is_help=True,
69
+ help="Process videos/images and save latent representations for video generation training.",
70
+ )
71
+
72
+
73
+ class MediaDataset(Dataset):
74
+ """
75
+ Dataset for processing video and image files.
76
+ This dataset is designed for media preprocessing workflows where you need to:
77
+ - Load and preprocess videos/images
78
+ - Apply resizing and cropping transformations
79
+ - Handle different resolution buckets
80
+ - Filter out invalid media files
81
+ - Optionally extract audio from video files
82
+ """
83
+
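+ # Construction sketch (illustrative arguments): MediaDataset("dataset.json", main_media_column="media_path",
+ # video_column="media_path", resolution_buckets=[(25, 768, 768)]) loads the paths, drops clips that are
+ # too short, and resizes/crops each item to its nearest bucket in __getitem__.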
84
+ def __init__(
85
+ self,
86
+ dataset_file: str | Path,
87
+ main_media_column: str,
88
+ video_column: str,
89
+ resolution_buckets: list[tuple[int, int, int]],
90
+ reshape_mode: str = "center",
91
+ with_audio: bool = False,
92
+ ) -> None:
93
+ """
94
+ Initialize the media dataset.
95
+ Args:
96
+ dataset_file: Path to CSV/JSON/JSONL metadata file
97
+ main_media_column: Column name for the primary media paths (used for output file naming)
+ video_column: Column name for the video paths to encode (may differ from main_media_column, e.g. for reference videos)
98
+ resolution_buckets: List of (frames, height, width) tuples
99
+ reshape_mode: How to crop videos ("center", "random")
100
+ with_audio: Whether to extract audio from video files
101
+ """
102
+ super().__init__()
103
+
104
+ self.dataset_file = Path(dataset_file)
105
+ self.main_media_column = main_media_column
106
+ self.resolution_buckets = resolution_buckets
107
+ self.reshape_mode = reshape_mode
108
+ self.with_audio = with_audio
109
+
110
+ # First load main media paths
111
+ self.main_media_paths = self._load_video_paths(main_media_column)
112
+
113
+ # Then load reference video paths
114
+ self.video_paths = self._load_video_paths(video_column)
115
+
116
+ # Filter out videos with insufficient frames
117
+ self._filter_valid_videos()
118
+
119
+ self.max_target_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
120
+
121
+ # Set up video transforms
122
+ self.transforms = transforms.Compose(
123
+ [
124
+ transforms.Lambda(lambda x: x.clamp_(0, 1)),
125
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
126
+ ]
127
+ )
128
+
129
+ def __len__(self) -> int:
130
+ return len(self.video_paths)
131
+
132
+ def __getitem__(self, index: int) -> dict[str, Any]:
133
+ """Get a single video/image with metadata, and optionally audio."""
134
+ if isinstance(index, list):
135
+ # Special case for BucketSampler - return cached data
136
+ return index
137
+
138
+ video_path: Path = self.video_paths[index]
139
+
140
+ # Compute relative path of the video
141
+ data_root = self.dataset_file.parent
142
+ relative_path = str(video_path.relative_to(data_root))
143
+ media_relative_path = str(self.main_media_paths[index].relative_to(data_root))
144
+
145
+ if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
146
+ media_tensor = self._preprocess_image(video_path)
147
+ fps = 1.0
148
+ audio_data = None # Images don't have audio
149
+ else:
150
+ media_tensor, fps = self._preprocess_video(video_path)
151
+
152
+ # Extract audio if enabled
153
+ if self.with_audio:
154
+ # Calculate target duration from the processed video frames
155
+ # This ensures audio is trimmed to match the exact video duration
156
+ # media_tensor is [C, F, H, W] so shape[1] is num_frames
157
+ target_duration = media_tensor.shape[1] / fps
158
+ audio_data = self._extract_audio(video_path, target_duration)
159
+ else:
160
+ audio_data = None
161
+
162
+ # media_tensor is [C, F, H, W] format for VAE compatibility
163
+ _, num_frames, height, width = media_tensor.shape
164
+
165
+ result = {
166
+ "video": media_tensor,
167
+ "relative_path": relative_path,
168
+ "main_media_relative_path": media_relative_path,
169
+ "video_metadata": {
170
+ "num_frames": num_frames,
171
+ "height": height,
172
+ "width": width,
173
+ "fps": fps,
174
+ },
175
+ }
176
+
177
+ # Add audio data if available
178
+ if audio_data is not None:
179
+ result["audio"] = audio_data
180
+
181
+ return result
182
+
183
+ @staticmethod
184
+ def _extract_audio(video_path: Path, target_duration: float) -> dict[str, torch.Tensor | int] | None:
185
+ """Extract audio track from a video file, trimmed to match video duration."""
186
+ try:
187
+ # torchaudio can extract audio from video files directly
188
+ # waveform shape: [channels, samples]
189
+ waveform, sample_rate = torchaudio.load(str(video_path))
190
+
191
+ # Trim or pad to target duration
192
+ target_samples = int(target_duration * sample_rate)
193
+ current_samples = waveform.shape[-1]
194
+
195
+ if current_samples > target_samples:
196
+ # Trim to target duration
197
+ waveform = waveform[..., :target_samples]
198
+ elif current_samples < target_samples:
199
+ # Pad with zeros to target duration
200
+ padding = target_samples - current_samples
201
+ waveform = torch.nn.functional.pad(waveform, (0, padding))
202
+ logger.warning(f"Padded audio to {target_duration:.2f} seconds for {video_path}")
203
+
204
+ return {"waveform": waveform, "sample_rate": sample_rate}
205
+
206
+ except Exception as e:
207
+ logger.debug(f"Could not extract audio from {video_path}: {e}")
208
+ return None
209
+
210
+ def _load_video_paths(self, column: str) -> list[Path]:
211
+ """Load video paths from the specified data source."""
212
+ if self.dataset_file.suffix == ".csv":
213
+ return self._load_video_paths_from_csv(column)
214
+ elif self.dataset_file.suffix == ".json":
215
+ return self._load_video_paths_from_json(column)
216
+ elif self.dataset_file.suffix == ".jsonl":
217
+ return self._load_video_paths_from_jsonl(column)
218
+ else:
219
+ raise ValueError("Expected `dataset_file` to be a path to a CSV, JSON, or JSONL file.")
220
+
221
+ def _load_video_paths_from_csv(self, column: str) -> list[Path]:
222
+ """Load video paths from a CSV file."""
223
+ df = pd.read_csv(self.dataset_file)
224
+ if column not in df.columns:
225
+ raise ValueError(f"Column '{column}' not found in CSV file")
226
+
227
+ data_root = self.dataset_file.parent
228
+ video_paths = [data_root / Path(line.strip()) for line in df[column].tolist()]
229
+
230
+ # Validate that all paths exist
231
+ invalid_paths = [path for path in video_paths if not path.is_file()]
232
+ if invalid_paths:
233
+ raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}")
234
+
235
+ return video_paths
236
+
237
+ def _load_video_paths_from_json(self, column: str) -> list[Path]:
238
+ """Load video paths from a JSON file."""
239
+ with open(self.dataset_file, "r", encoding="utf-8") as file:
240
+ data = json.load(file)
241
+
242
+ if not isinstance(data, list):
243
+ raise ValueError("JSON file must contain a list of objects")
244
+
245
+ data_root = self.dataset_file.parent
246
+ video_paths = []
247
+ for entry in data:
248
+ if column not in entry:
249
+ raise ValueError(f"Key '{column}' not found in JSON entry")
250
+ video_paths.append(data_root / Path(entry[column].strip()))
251
+
252
+ # Validate that all paths exist
253
+ invalid_paths = [path for path in video_paths if not path.is_file()]
254
+ if invalid_paths:
255
+ raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}")
256
+
257
+ return video_paths
258
+
259
+ def _load_video_paths_from_jsonl(self, column: str) -> list[Path]:
260
+ """Load video paths from a JSONL file."""
261
+ data_root = self.dataset_file.parent
262
+ video_paths = []
263
+ with open(self.dataset_file, "r", encoding="utf-8") as file:
264
+ for line in file:
265
+ entry = json.loads(line)
266
+ if column not in entry:
267
+ raise ValueError(f"Key '{column}' not found in JSONL entry")
268
+ video_paths.append(data_root / Path(entry[column].strip()))
269
+
270
+ # Validate that all paths exist
271
+ invalid_paths = [path for path in video_paths if not path.is_file()]
272
+ if invalid_paths:
273
+ raise ValueError(f"Found {len(invalid_paths)} invalid video paths. First few: {invalid_paths[:5]}")
274
+
275
+ return video_paths
276
+
277
+ def _filter_valid_videos(self) -> None:
278
+ """Filter out videos with insufficient frames."""
279
+ original_length = len(self.video_paths)
280
+ valid_video_paths = []
281
+ valid_main_media_paths = []
282
+ min_frames_required = min(self.resolution_buckets, key=lambda x: x[0])[0]
283
+
284
+ for i, video_path in enumerate(self.video_paths):
285
+ if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
286
+ valid_video_paths.append(video_path)
287
+ valid_main_media_paths.append(self.main_media_paths[i])
288
+ continue
289
+
290
+ try:
291
+ frame_count = get_video_frame_count(video_path)
292
+
293
+ if frame_count >= min_frames_required:
294
+ valid_video_paths.append(video_path)
295
+ valid_main_media_paths.append(self.main_media_paths[i])
296
+ else:
297
+ logger.warning(
298
+ f"Skipping video at {video_path} - has {frame_count} frames, "
299
+ f"which is less than the minimum required frames ({min_frames_required})"
300
+ )
301
+ except Exception as e:
302
+ logger.warning(f"Failed to read video at {video_path}: {e!s}")
303
+
304
+ # Update both path lists to maintain synchronization
305
+ self.video_paths = valid_video_paths
306
+ self.main_media_paths = valid_main_media_paths
307
+
308
+ if len(self.video_paths) < original_length:
309
+ logger.warning(
310
+ f"Filtered out {original_length - len(self.video_paths)} videos with insufficient frames. "
311
+ f"Proceeding with {len(self.video_paths)} valid videos."
312
+ )
313
+
314
+ def _preprocess_image(self, path: Path) -> torch.Tensor:
315
+ """Preprocess a single image by resizing and applying transforms."""
316
+ image = open_image_as_srgb(path)
317
+ image = to_tensor(image)
318
+ image = image.unsqueeze(0) # Add frame dimension [1, C, H, W] for bucket selection
319
+
320
+ # Find nearest resolution bucket and resize
321
+ nearest_bucket = self._get_resolution_bucket_for_item(image)
322
+ _, target_height, target_width = nearest_bucket
323
+ image_resized = self._resize_and_crop(image, target_height, target_width)
324
+ # _resize_and_crop returns [C, H, W] for single-frame input (squeeze removes dim 0)
325
+
326
+ # Apply transforms
327
+ image = self.transforms(image_resized) # [C, H, W] -> [C, H, W]
328
+
329
+ # Add frame dimension in VAE format: [C, H, W] -> [C, 1, H, W]
330
+ image = image.unsqueeze(1)
331
+ return image
332
+
333
+ def _preprocess_video(self, path: Path) -> tuple[torch.Tensor, float]:
334
+ """Preprocess a video by loading, resizing, and applying transforms.
335
+ Returns:
336
+ Tuple of (video tensor in [C, F, H, W] format, fps)
337
+ """
338
+ # Load video frames up to max_target_frames
339
+ video, fps = read_video(path, max_frames=self.max_target_frames)
340
+
341
+ nearest_bucket = self._get_resolution_bucket_for_item(video)
342
+ target_num_frames, target_height, target_width = nearest_bucket
343
+ frames_resized = self._resize_and_crop(video, target_height, target_width)
344
+
345
+ # Trim video to target number of frames
346
+ frames_resized = frames_resized[:target_num_frames]
347
+
348
+ # Apply transforms to each frame and stack
349
+ video = torch.stack([self.transforms(frame) for frame in frames_resized], dim=0)
350
+
351
+ # Permute [F,C,H,W] -> [C,F,H,W] for VAE compatibility
352
+ # After DataLoader batching, this becomes [B,C,F,H,W] which VAE expects
353
+ video = video.permute(1, 0, 2, 3).contiguous()
354
+
355
+ return video, fps
356
+
357
+ def _get_resolution_bucket_for_item(self, media_tensor: torch.Tensor) -> tuple[int, int, int]:
358
+ """Get the nearest resolution bucket for the given media tensor."""
359
+ num_frames, _, height, width = media_tensor.shape
360
+
361
+ def distance(bucket: tuple[int, int, int]) -> tuple:
362
+ bucket_num_frames, bucket_height, bucket_width = bucket
363
+ # Lexicographic key:
364
+ # 1) minimize aspect-ratio difference (compared in log-scale so wide and tall deviations are penalized symmetrically)
365
+ # 2) prefer buckets with more frames (by using negative)
366
+ # 3) prefer buckets with larger spatial area (by using negative)
367
+ return (
368
+ abs(math.log(width / height) - math.log(bucket_width / bucket_height)),
369
+ -bucket_num_frames,
370
+ -(bucket_height * bucket_width),
371
+ )
372
+
373
+ # Keep only buckets with <= available frames
374
+ relevant_buckets = [b for b in self.resolution_buckets if b[0] <= num_frames]
375
+ if not relevant_buckets:
376
+ raise ValueError(f"No resolution buckets have <= {num_frames} frames. Available: {self.resolution_buckets}")
377
+
378
+ # Find the bucket with the minimal distance (according to the function above) to the media item's shape.
379
+ nearest_bucket = min(relevant_buckets, key=distance)
380
+
381
+ return nearest_bucket
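+ # Illustrative example of the lexicographic selection above (hypothetical values): a 1920x1080 video
+ # with 200 frames and buckets [(49, 512, 512), (25, 768, 768)] ties on aspect-ratio distance
+ # (both buckets are 1:1), so rule 2 prefers the 49-frame bucket.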
382
+
383
+ def _resize_and_crop(self, media_tensor: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor:
384
+ """Resize and crop tensor to target size."""
385
+ # Get current dimensions
386
+ current_height, current_width = media_tensor.shape[2], media_tensor.shape[3]
387
+
388
+ # Calculate aspect ratios to determine which dimension to resize first
389
+ current_aspect = current_width / current_height
390
+ target_aspect = target_width / target_height
391
+
392
+ # Resize while maintaining aspect ratio - scale to make the smaller dimension fit
393
+ if current_aspect > target_aspect:
394
+ # Current is wider than target, so scale by height
395
+ new_width = int(current_width * target_height / current_height)
396
+ media_tensor = resize(
397
+ media_tensor,
398
+ size=[target_height, new_width], # type: ignore
399
+ interpolation=InterpolationMode.BICUBIC,
400
+ )
401
+ else:
402
+ # Current is taller than target, so scale by width
403
+ new_height = int(current_height * target_width / current_width)
404
+ media_tensor = resize(
405
+ media_tensor,
406
+ size=[new_height, target_width],
407
+ interpolation=InterpolationMode.BICUBIC,
408
+ )
409
+
410
+ # Update dimensions after resize
411
+ current_height, current_width = media_tensor.shape[2], media_tensor.shape[3]
412
+ media_tensor = media_tensor.squeeze(0)
413
+
414
+ # Calculate how much we need to crop from each dimension
415
+ delta_h = current_height - target_height
416
+ delta_w = current_width - target_width
417
+
418
+ # Determine crop position based on reshape mode
419
+ if self.reshape_mode == "random":
420
+ # Random crop position
421
+ top = np.random.randint(0, delta_h + 1)
422
+ left = np.random.randint(0, delta_w + 1)
423
+ elif self.reshape_mode == "center":
424
+ # Center crop
425
+ top, left = delta_h // 2, delta_w // 2
426
+ else:
427
+ raise ValueError(f"Unsupported reshape mode: {self.reshape_mode}")
428
+
429
+ # Perform the final crop to exact target dimensions
430
+ media_tensor = crop(media_tensor, top=top, left=left, height=target_height, width=target_width)
431
+ return media_tensor
432
+
433
+
434
+ def compute_latents( # noqa: PLR0913, PLR0915
435
+ dataset_file: str | Path,
436
+ video_column: str,
437
+ resolution_buckets: list[tuple[int, int, int]],
438
+ output_dir: str,
439
+ model_path: str,
440
+ main_media_column: str | None = None,
441
+ reshape_mode: str = "center",
442
+ batch_size: int = 1,
443
+ device: str = "cuda",
444
+ vae_tiling: bool = False,
445
+ with_audio: bool = False,
446
+ audio_output_dir: str | None = None,
447
+ ) -> None:
448
+ """
449
+ Process videos and save latent representations.
450
+ Args:
451
+ dataset_file: Path to metadata file (CSV/JSON/JSONL) containing video paths
452
+ video_column: Column name for video paths in the metadata file
453
+ resolution_buckets: List of (frames, height, width) tuples
454
+ output_dir: Directory to save video latents
455
+ model_path: Path to LTX-2 checkpoint (.safetensors)
456
+ reshape_mode: How to crop videos ("center", "random")
457
+ main_media_column: Column name for main media paths (if different from video_column)
458
+ batch_size: Batch size for processing
459
+ device: Device to use for computation
460
+ vae_tiling: Whether to enable VAE tiling
461
+ with_audio: Whether to extract and encode audio from videos
462
+ audio_output_dir: Directory to save audio latents (required if with_audio=True)
463
+ """
464
+ # Validate audio parameters
465
+ if with_audio and audio_output_dir is None:
466
+ raise ValueError("audio_output_dir must be provided when with_audio=True")
467
+
468
+ console = Console()
469
+ torch_device = torch.device(device)
470
+
471
+ # Create dataset
472
+ dataset = MediaDataset(
473
+ dataset_file=dataset_file,
474
+ main_media_column=main_media_column or video_column,
475
+ video_column=video_column,
476
+ resolution_buckets=resolution_buckets,
477
+ reshape_mode=reshape_mode,
478
+ with_audio=with_audio,
479
+ )
480
+ logger.info(f"Loaded {len(dataset)} valid media files")
481
+
482
+ output_path = Path(output_dir)
483
+ output_path.mkdir(parents=True, exist_ok=True)
484
+
485
+ # Set up audio output directory if needed
486
+ audio_output_path = None
487
+ if with_audio:
488
+ audio_output_path = Path(audio_output_dir)
489
+ audio_output_path.mkdir(parents=True, exist_ok=True)
490
+
491
+ # Load video VAE encoder
492
+ with console.status(f"[bold]Loading video VAE encoder from [cyan]{model_path}[/]...", spinner="dots"):
493
+ vae = load_video_vae_encoder(model_path, device=torch_device, dtype=torch.bfloat16)
494
+
495
+ # Load audio VAE encoder and audio processor if needed
496
+ audio_vae_encoder = None
497
+ audio_processor = None
498
+ if with_audio:
499
+ with console.status(f"[bold]Loading audio VAE encoder from [cyan]{model_path}[/]...", spinner="dots"):
500
+ audio_vae_encoder = load_audio_vae_encoder(
501
+ checkpoint_path=model_path,
502
+ device=torch_device,
503
+ dtype=torch.float32, # Audio VAE needs float32 for quality. TODO: re-test with bfloat16.
504
+ )
505
+ # Create audio processor for waveform-to-spectrogram conversion
506
+ audio_processor = AudioProcessor(
507
+ target_sample_rate=audio_vae_encoder.sample_rate,
508
+ mel_bins=audio_vae_encoder.mel_bins,
509
+ mel_hop_length=audio_vae_encoder.mel_hop_length,
510
+ n_fft=audio_vae_encoder.n_fft,
511
+ ).to(torch_device)
512
+
513
+ # Create dataloader
514
+ # Note: batch_size=1 required when with_audio because audio extraction can fail for some videos,
515
+ # and the default collate function can't handle mixed None/dict values across a batch.
516
+ if with_audio and batch_size > 1:
517
+ logger.warning("Audio processing requires batch_size=1. Overriding batch_size to 1.")
518
+ batch_size = 1
519
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
520
+
521
+ # Track audio statistics
522
+ audio_success_count = 0
523
+ audio_skip_count = 0
524
+
525
+ # Process batches
526
+ with Progress(
527
+ SpinnerColumn(),
528
+ TextColumn("[progress.description]{task.description}"),
529
+ BarColumn(),
530
+ TaskProgressColumn(),
531
+ MofNCompleteColumn(),
532
+ TimeElapsedColumn(),
533
+ TimeRemainingColumn(),
534
+ console=console,
535
+ ) as progress:
536
+ task = progress.add_task("Processing videos", total=len(dataloader))
537
+
538
+ for batch in dataloader:
539
+ # Get video tensor - shape is [B, C, F, H, W] after the DataLoader batches the [C, F, H, W] items
540
+ video = batch["video"]
541
+
542
+ # Encode video
543
+ with torch.inference_mode():
544
+ video_latent_data = encode_video(vae=vae, video=video, use_tiling=vae_tiling)
545
+
546
+ # Save latents for each item in batch
547
+ for i in range(len(batch["relative_path"])):
548
+ output_rel_path = Path(batch["main_media_relative_path"][i]).with_suffix(".pt")
549
+ output_file = output_path / output_rel_path
550
+
551
+ # Create output directory maintaining structure
552
+ output_file.parent.mkdir(parents=True, exist_ok=True)
553
+
554
+ # Index into batch to get this item's latents
555
+ latent_data = {
556
+ "latents": video_latent_data["latents"][i].cpu().contiguous(), # [C, F', H', W']
557
+ "num_frames": video_latent_data["num_frames"],
558
+ "height": video_latent_data["height"],
559
+ "width": video_latent_data["width"],
560
+ "fps": batch["video_metadata"]["fps"][i].item(),
561
+ }
562
+
563
+ torch.save(latent_data, output_file)
564
+
565
+ # Process audio if enabled (audio is already extracted by the dataset)
566
+ if with_audio:
567
+ audio_batch = batch.get("audio")
568
+ if audio_batch is not None:
569
+ # Extract the i-th item from batched audio data
570
+ # DataLoader collates [channels, samples] -> [batch, channels, samples]
571
+ audio_data = Audio(
572
+ waveform=audio_batch["waveform"][i],
573
+ sampling_rate=audio_batch["sample_rate"][i].item(),
574
+ )
575
+
576
+ # Encode audio
577
+ with torch.inference_mode():
578
+ audio_latents = encode_audio(audio_vae_encoder, audio_processor, audio_data)
579
+
580
+ # Save audio latents
581
+ audio_output_file = audio_output_path / output_rel_path
582
+ audio_output_file.parent.mkdir(parents=True, exist_ok=True)
583
+
584
+ audio_save_data = {
585
+ "latents": audio_latents["latents"].cpu().contiguous(),
586
+ "num_time_steps": audio_latents["num_time_steps"],
587
+ "frequency_bins": audio_latents["frequency_bins"],
588
+ "duration": audio_latents["duration"],
589
+ }
590
+
591
+ torch.save(audio_save_data, audio_output_file)
592
+ audio_success_count += 1
593
+ else:
594
+ # Video has no audio track
595
+ audio_skip_count += 1
596
+
597
+ progress.advance(task)
598
+
599
+ # Log summary
600
+ logger.info(f"Processed {len(dataset)} videos. Latents saved to {output_path}")
601
+ if with_audio:
602
+ logger.info(
603
+ f"Audio processing: {audio_success_count} videos with audio, "
604
+ f"{audio_skip_count} videos without audio (skipped)"
605
+ )
606
+
607
+
608
+ def encode_video(
609
+ vae: torch.nn.Module,
610
+ video: torch.Tensor,
611
+ dtype: torch.dtype | None = None,
612
+ use_tiling: bool = False,
613
+ tile_size: int = DEFAULT_TILE_SIZE,
614
+ tile_overlap: int = DEFAULT_TILE_OVERLAP,
615
+ ) -> dict[str, torch.Tensor | int]:
616
+ """Encode video into non-patchified latent representation.
617
+ Args:
618
+ vae: Video VAE encoder model
619
+ video: Input tensor of shape [B, C, F, H, W] (batch, channels, frames, height, width)
620
+ This is the format expected by the VAE encoder.
621
+ dtype: Target dtype for output latents
622
+ use_tiling: Whether to use spatial tiling for memory efficiency
623
+ tile_size: Tile size in pixels (must be divisible by 32)
624
+ tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
625
+ Returns:
626
+ Dict containing non-patchified latents and shape information:
627
+ {
628
+ "latents": Tensor[B, C, F', H', W'], # Non-patchified format with batch dim
629
+ "num_frames": int, # Latent frame count
630
+ "height": int, # Latent height
631
+ "width": int, # Latent width
632
+ }
633
+ """
634
+ device = next(vae.parameters()).device
635
+ vae_dtype = next(vae.parameters()).dtype
636
+
637
+ # Add batch dimension if needed
638
+ if video.ndim == 4:
639
+ video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W]
640
+
641
+ video = video.to(device=device, dtype=vae_dtype)
642
+
643
+ # Choose encoding method based on tiling flag
644
+ if use_tiling:
645
+ latents = tiled_encode_video(
646
+ vae=vae,
647
+ video=video,
648
+ tile_size=tile_size,
649
+ tile_overlap=tile_overlap,
650
+ )
651
+ else:
652
+ # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W']
653
+ latents = vae(video)
654
+
655
+ if dtype is not None:
656
+ latents = latents.to(dtype=dtype)
657
+
658
+ _, _, num_frames, height, width = latents.shape
659
+
660
+ return {
661
+ "latents": latents, # [B, C, F', H', W']
662
+ "num_frames": num_frames,
663
+ "height": height,
664
+ "width": width,
665
+ }
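+ # Shape example (illustrative): a single 25-frame 768x768 RGB clip, video.shape == [1, 3, 25, 768, 768],
+ # yields latents of shape [1, 128, 4, 24, 24], since F' = 1 + (25 - 1) / 8, H' = 768 / 32 and W' = 768 / 32.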
666
+
667
+
668
+ def tiled_encode_video( # noqa: PLR0912, PLR0915
669
+ vae: torch.nn.Module,
670
+ video: torch.Tensor,
671
+ tile_size: int = DEFAULT_TILE_SIZE,
672
+ tile_overlap: int = DEFAULT_TILE_OVERLAP,
673
+ ) -> torch.Tensor:
674
+ """Encode video using spatial tiling for memory efficiency.
675
+ Splits the video into overlapping spatial tiles, encodes each tile separately,
676
+ and blends the results using linear feathering in the overlap regions.
677
+ Args:
678
+ vae: Video VAE encoder model
679
+ video: Input tensor of shape [B, C, F, H, W]
680
+ tile_size: Tile size in pixels (must be divisible by 32)
681
+ tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
682
+ Returns:
683
+ Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]
684
+ """
685
+ batch, _channels, frames, height, width = video.shape
686
+ device = video.device
687
+ dtype = video.dtype
688
+
689
+ # Validate tile parameters
690
+ if tile_size % VAE_SPATIAL_FACTOR != 0:
691
+ raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}")
692
+ if tile_overlap % VAE_SPATIAL_FACTOR != 0:
693
+ raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
694
+ if tile_overlap >= tile_size:
695
+ raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")
696
+
697
+ # If video fits in a single tile, use regular encoding
698
+ if height <= tile_size and width <= tile_size:
699
+ return vae(video)
700
+
701
+ # Calculate output dimensions
702
+ # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8
703
+ output_height = height // VAE_SPATIAL_FACTOR
704
+ output_width = width // VAE_SPATIAL_FACTOR
705
+ output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR
706
+
707
+ # Latent channels (128 for LTX-2)
708
+ # Hard-coded here; could alternatively be probed with a small test encode
709
+ latent_channels = 128
710
+
711
+ # Initialize output and weight tensors
712
+ output = torch.zeros(
713
+ (batch, latent_channels, output_frames, output_height, output_width),
714
+ device=device,
715
+ dtype=dtype,
716
+ )
717
+ weights = torch.zeros(
718
+ (batch, 1, output_frames, output_height, output_width),
719
+ device=device,
720
+ dtype=dtype,
721
+ )
722
+
723
+ # Calculate tile positions with overlap
724
+ # Step size is tile_size - tile_overlap
725
+ step_h = tile_size - tile_overlap
726
+ step_w = tile_size - tile_overlap
727
+
728
+ h_positions = list(range(0, max(1, height - tile_overlap), step_h))
729
+ w_positions = list(range(0, max(1, width - tile_overlap), step_w))
730
+
731
+ # Ensure last tile covers the edge
732
+ if h_positions[-1] + tile_size < height:
733
+ h_positions.append(height - tile_size)
734
+ if w_positions[-1] + tile_size < width:
735
+ w_positions.append(width - tile_size)
736
+
737
+ # Remove duplicates and sort
738
+ h_positions = sorted(set(h_positions))
739
+ w_positions = sorted(set(w_positions))
740
+
741
+ # Overlap in latent space
742
+ overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
743
+ overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR
744
+
745
+ # Process each tile
746
+ for h_pos in h_positions:
747
+ for w_pos in w_positions:
748
+ # Calculate tile boundaries in input space
749
+ h_start = max(0, h_pos)
750
+ w_start = max(0, w_pos)
751
+ h_end = min(h_start + tile_size, height)
752
+ w_end = min(w_start + tile_size, width)
753
+
754
+ # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR
755
+ tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
756
+ tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
757
+
758
+ if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
759
+ continue
760
+
761
+ # Adjust end positions
762
+ h_end = h_start + tile_h
763
+ w_end = w_start + tile_w
764
+
765
+ # Extract tile
766
+ tile = video[:, :, :, h_start:h_end, w_start:w_end]
767
+
768
+ # Encode tile
769
+ encoded_tile = vae(tile)
770
+
771
+ # Get actual encoded dimensions
772
+ _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape
773
+
774
+ # Calculate output positions
775
+ out_h_start = h_start // VAE_SPATIAL_FACTOR
776
+ out_w_start = w_start // VAE_SPATIAL_FACTOR
777
+ out_h_end = min(out_h_start + tile_out_height, output_height)
778
+ out_w_end = min(out_w_start + tile_out_width, output_width)
779
+
780
+ # Trim encoded tile if necessary
781
+ actual_tile_h = out_h_end - out_h_start
782
+ actual_tile_w = out_w_end - out_w_start
783
+ encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]
784
+
785
+ # Create blending mask with linear feathering at edges
786
+ mask = torch.ones(
787
+ (1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
788
+ device=device,
789
+ dtype=dtype,
790
+ )
791
+
792
+ # Apply feathering at edges (linear blend in overlap regions)
793
+ # Top edge (fade-in along the height dimension)
794
+ if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
795
+ fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
796
+ mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)
797
+
798
+ # Bottom edge (fade-out along the height dimension)
799
+ if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
800
+ fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
801
+ mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)
802
+
803
+ # Left edge (fade-in along the width dimension)
804
+ if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
805
+ fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
806
+ mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)
807
+
808
+ # Right edge (fade-out along the width dimension)
809
+ if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
810
+ fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
811
+ mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)
812
+
813
+ # Accumulate weighted results
814
+ output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
815
+ weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask
816
+
817
+ # Normalize by weights (avoid division by zero)
818
+ output = output / (weights + 1e-8)
819
+
820
+ return output
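+ # Worked example of the tiling grid above (illustrative): for a 1024-pixel-wide input with the defaults
+ # tile_size=512 and tile_overlap=128, step_w = 384 and w_positions = [0, 384, 768]; each tile is encoded
+ # separately and blended with the linear feathering masks in the overlap regions.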
821
+
822
+
823
+ def encode_audio(
824
+ audio_vae_encoder: torch.nn.Module,
825
+ audio_processor: torch.nn.Module,
826
+ audio: Audio,
827
+ ) -> dict[str, torch.Tensor | int | float]:
828
+ """Encode audio waveform into latent representation.
829
+ Args:
830
+ audio_vae_encoder: Audio VAE encoder model from ltx-core
831
+ audio_processor: AudioProcessor for waveform-to-spectrogram conversion
832
+ audio: Audio container with waveform tensor and sampling rate.
833
+ Returns:
834
+ Dict containing audio latents and shape information:
835
+ {
836
+ "latents": Tensor[C, T, F], # Non-patchified format
837
+ "num_time_steps": int,
838
+ "frequency_bins": int,
839
+ "duration": float,
840
+ }
841
+ """
842
+ device = next(audio_vae_encoder.parameters()).device
843
+ dtype = next(audio_vae_encoder.parameters()).dtype
844
+
845
+ waveform = audio.waveform.to(device=device, dtype=dtype)
846
+
847
+ # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples]
848
+ if waveform.dim() == 2:
849
+ waveform = waveform.unsqueeze(0)
850
+
851
+ # Calculate duration
852
+ duration = waveform.shape[-1] / audio.sampling_rate
853
+
854
+ # Convert waveform to mel spectrogram using AudioProcessor
855
+ mel_spectrogram = audio_processor.waveform_to_mel(Audio(waveform=waveform, sampling_rate=audio.sampling_rate))
856
+ mel_spectrogram = mel_spectrogram.to(dtype=dtype)
857
+
858
+ # Encode mel spectrogram to latents
859
+ latents = audio_vae_encoder(mel_spectrogram)
860
+
861
+ # latents shape: [batch, channels, time, freq] = [1, 8, T, 16]
862
+ _, _channels, time_steps, freq_bins = latents.shape
863
+
864
+ return {
865
+ "latents": latents.squeeze(0), # [C, T, F] - remove batch dim
866
+ "num_time_steps": time_steps,
867
+ "frequency_bins": freq_bins,
868
+ "duration": duration,
869
+ }
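+ # Shape example (illustrative): for a stereo clip, audio.waveform is [2, samples] and the returned
+ # latents have shape [8, T, 16] (AUDIO_LATENT_CHANNELS x time steps x AUDIO_FREQUENCY_BINS), where T
+ # depends on the clip duration and the mel hop length.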
870
+
871
+
872
+ def parse_resolution_buckets(resolution_buckets_str: str) -> list[tuple[int, int, int]]:
873
+ """Parse resolution buckets from string format to list of tuples (frames, height, width)"""
874
+ resolution_buckets = []
875
+ for bucket_str in resolution_buckets_str.split(";"):
876
+ w, h, f = map(int, bucket_str.split("x"))
877
+
878
+ if w % VAE_SPATIAL_FACTOR != 0 or h % VAE_SPATIAL_FACTOR != 0:
879
+ raise typer.BadParameter(
880
+ f"Width and height must be multiples of {VAE_SPATIAL_FACTOR}, got {w}x{h}",
881
+ param_hint="resolution-buckets",
882
+ )
883
+
884
+ if f % VAE_TEMPORAL_FACTOR != 1:
885
+ raise typer.BadParameter(
886
+ f"Number of frames must be a multiple of {VAE_TEMPORAL_FACTOR} plus 1, got {f}",
887
+ param_hint="resolution-buckets",
888
+ )
889
+
890
+ resolution_buckets.append((f, h, w))
891
+ return resolution_buckets
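+ # Example (illustrative): parse_resolution_buckets("768x768x25;512x512x49") returns
+ # [(25, 768, 768), (49, 512, 512)]; note that buckets are stored as (frames, height, width).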
892
+
893
+
894
+ def compute_scaled_resolution_buckets(
895
+ resolution_buckets: list[tuple[int, int, int]],
896
+ scale_factor: int,
897
+ ) -> list[tuple[int, int, int]]:
898
+ """Compute scaled resolution buckets and validate the results."""
899
+ if scale_factor == 1:
900
+ return resolution_buckets
901
+
902
+ scaled_buckets = []
903
+ for frames, height, width in resolution_buckets:
904
+ # Validate that scale factor evenly divides the dimensions
905
+ if height % scale_factor != 0:
906
+ raise ValueError(
907
+ f"Height {height} is not evenly divisible by scale factor {scale_factor}. "
908
+ f"Choose a scale factor that divides {height} evenly."
909
+ )
910
+ if width % scale_factor != 0:
911
+ raise ValueError(
912
+ f"Width {width} is not evenly divisible by scale factor {scale_factor}. "
913
+ f"Choose a scale factor that divides {width} evenly."
914
+ )
915
+
916
+ scaled_height = height // scale_factor
917
+ scaled_width = width // scale_factor
918
+
919
+ # Validate scaled dimensions are divisible by VAE spatial factor
920
+ if scaled_height % VAE_SPATIAL_FACTOR != 0:
921
+ raise ValueError(
922
+ f"Scaled height {scaled_height} (from {height} / {scale_factor}) "
923
+ f"is not divisible by {VAE_SPATIAL_FACTOR}. "
924
+ f"Choose a different scale factor or adjust your resolution buckets."
925
+ )
926
+ if scaled_width % VAE_SPATIAL_FACTOR != 0:
927
+ raise ValueError(
928
+ f"Scaled width {scaled_width} (from {width} / {scale_factor}) "
929
+ f"is not divisible by {VAE_SPATIAL_FACTOR}. "
930
+ f"Choose a different scale factor or adjust your resolution buckets."
931
+ )
932
+
933
+ scaled_buckets.append((frames, scaled_height, scaled_width))
934
+
935
+ return scaled_buckets
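+ # Example (illustrative): compute_scaled_resolution_buckets([(25, 1024, 512)], scale_factor=2) returns
+ # [(25, 512, 256)]; a ValueError is raised if a scaled dimension is no longer divisible by 32.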
936
+
937
+
938
+ @app.command()
939
+ def main( # noqa: PLR0913
940
+ dataset_file: str = typer.Argument(
941
+ ...,
942
+ help="Path to metadata file (CSV/JSON/JSONL) containing video paths",
943
+ ),
944
+ resolution_buckets: str = typer.Option(
945
+ ...,
946
+ help='Resolution buckets in format "WxHxF;WxHxF;..." (e.g. "768x768x25;512x512x49")',
947
+ ),
948
+ output_dir: str = typer.Option(
949
+ ...,
950
+ help="Output directory to save video latents",
951
+ ),
952
+ model_path: str = typer.Option(
953
+ ...,
954
+ help="Path to LTX-2 checkpoint (.safetensors file)",
955
+ ),
956
+ video_column: str = typer.Option(
957
+ default="media_path",
958
+ help="Column name in the dataset JSON/JSONL/CSV file containing video paths",
959
+ ),
960
+ batch_size: int = typer.Option(
961
+ default=1,
962
+ help="Batch size for processing",
963
+ ),
964
+ device: str = typer.Option(
965
+ default="cuda",
966
+ help="Device to use for computation",
967
+ ),
968
+ vae_tiling: bool = typer.Option(
969
+ default=False,
970
+ help="Enable VAE tiling for larger video resolutions",
971
+ ),
972
+ reshape_mode: str = typer.Option(
973
+ default="center",
974
+ help="How to crop videos: 'center' or 'random'",
975
+ ),
976
+ with_audio: bool = typer.Option(
977
+ default=False,
978
+ help="Extract and encode audio from video files",
979
+ ),
980
+ audio_output_dir: str | None = typer.Option(
981
+ default=None,
982
+ help="Output directory for audio latents (required if --with-audio is set)",
983
+ ),
984
+ ) -> None:
985
+ """Process videos/images and save latent representations for video generation training.
986
+ This script processes videos and images from metadata files and saves latent representations
987
+ that can be used for training video generation models. The output latents will maintain
988
+ the same folder structure and naming as the corresponding media files.
989
+ Examples:
990
+ # Process videos from a CSV file
991
+ python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\
992
+ --output-dir ./latents --model-path /path/to/ltx2.safetensors
993
+ # Process videos from a JSON file with custom video column
994
+ python scripts/process_videos.py dataset.json --resolution-buckets 768x768x25 \\
995
+ --output-dir ./latents --model-path /path/to/ltx2.safetensors --video-column "video_path"
996
+ # Enable VAE tiling to save GPU VRAM
997
+ python scripts/process_videos.py dataset.csv --resolution-buckets 1024x1024x25 \\
998
+ --output-dir ./latents --model-path /path/to/ltx2.safetensors --vae-tiling
999
+ # Process videos with audio
1000
+ python scripts/process_videos.py dataset.csv --resolution-buckets 768x768x25 \\
1001
+ --output-dir ./latents --model-path /path/to/ltx2.safetensors \\
1002
+ --with-audio --audio-output-dir ./audio_latents
1003
+ """
1004
+
1005
+ # Validate dataset file exists
1006
+ if not Path(dataset_file).is_file():
1007
+ raise typer.BadParameter(f"Dataset file not found: {dataset_file}")
1008
+
1009
+ # Validate audio parameters
1010
+ if with_audio and audio_output_dir is None:
1011
+ raise typer.BadParameter("--audio-output-dir is required when --with-audio is set")
1012
+
1013
+ # Parse resolution buckets
1014
+ parsed_resolution_buckets = parse_resolution_buckets(resolution_buckets)
1015
+
1016
+ if len(parsed_resolution_buckets) > 1:
1017
+ logger.warning(
1018
+ "Using multiple resolution buckets. "
1019
+ "When training with multiple resolution buckets, you must use a batch size of 1."
1020
+ )
1021
+
1022
+ # Process latents
1023
+ compute_latents(
1024
+ dataset_file=dataset_file,
1025
+ video_column=video_column,
1026
+ resolution_buckets=parsed_resolution_buckets,
1027
+ output_dir=output_dir,
1028
+ model_path=model_path,
1029
+ reshape_mode=reshape_mode,
1030
+ batch_size=batch_size,
1031
+ device=device,
1032
+ vae_tiling=vae_tiling,
1033
+ with_audio=with_audio,
1034
+ audio_output_dir=audio_output_dir,
1035
+ )
1036
+
1037
+
1038
+ if __name__ == "__main__":
1039
+ app()
packages/ltx-trainer/scripts/split_scenes.py ADDED
@@ -0,0 +1,417 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Split video into scenes using PySceneDetect.
5
+ This script provides a command-line interface for splitting videos into scenes using various detection algorithms.
6
+ It supports multiple detection methods, preview image generation, and customizable parameters for fine-tuning
7
+ the scene detection process.
8
+ Basic usage:
9
+ # Split video using default content-based detection
10
+ split_scenes.py input.mp4 output_dir/
11
+ # Save 3 preview images per scene
12
+ split_scenes.py input.mp4 output_dir/ --save-images 3
13
+ # Process specific duration and filter short scenes
14
+ split_scenes.py input.mp4 output_dir/ --duration 60s --filter-shorter-than 2s
15
+ Advanced usage:
16
+ # Content detection with minimum scene length and frame skip
17
+ split_scenes.py input.mp4 output_dir/ --detector content --min-scene-length 30 --frame-skip 2
18
+ # Use adaptive detection with custom detector and detector parameters
19
+ split_scenes.py input.mp4 output_dir/ --detector adaptive --threshold 3.0 --adaptive-window 10
20
+ """
21
+
22
+ from enum import Enum
23
+ from pathlib import Path
24
+ from typing import List, Optional, Tuple
25
+
26
+ import typer
27
+ from scenedetect import (
28
+ AdaptiveDetector,
29
+ ContentDetector,
30
+ HistogramDetector,
31
+ SceneManager,
32
+ ThresholdDetector,
33
+ open_video,
34
+ )
35
+ from scenedetect.frame_timecode import FrameTimecode
36
+ from scenedetect.scene_manager import SceneDetector, write_scene_list_html
37
+ from scenedetect.scene_manager import save_images as save_scene_images
38
+ from scenedetect.stats_manager import StatsManager
39
+ from scenedetect.video_splitter import split_video_ffmpeg
40
+
41
+ app = typer.Typer(no_args_is_help=True, help="Split video into scenes using PySceneDetect.")
42
+
43
+
44
+ class DetectorType(str, Enum):
45
+ """Available scene detection algorithms."""
46
+
47
+ CONTENT = "content" # Detects fast cuts using HSV color space
48
+ ADAPTIVE = "adaptive" # Two-pass content detection; more robust to fast camera movement
49
+ THRESHOLD = "threshold" # Detects cuts and slow fades in/out relative to a fixed brightness threshold
50
+ HISTOGRAM = "histogram" # Detects based on YUV histogram differences in adjacent frames
51
+
52
+
53
+ def create_detector(
54
+ detector_type: DetectorType,
55
+ threshold: Optional[float] = None,
56
+ min_scene_len: Optional[int] = None,
57
+ luma_only: Optional[bool] = None,
58
+ adaptive_window: Optional[int] = None,
59
+ fade_bias: Optional[float] = None,
60
+ ) -> SceneDetector:
61
+ """Create a scene detector based on the specified type and parameters.
62
+ Args:
63
+ detector_type: Type of detector to create
64
+ threshold: Detection threshold (meaning varies by detector)
65
+ min_scene_len: Minimum scene length in frames
66
+ luma_only: If True, only use brightness for content detection
67
+ adaptive_window: Window size for adaptive detection
68
+ fade_bias: Bias for fade in/out detection (-1.0 to 1.0)
69
+ Note: Parameters set to None will use the detector's built-in default values.
70
+ Returns:
71
+ Configured scene detector instance
72
+ """
73
+ # Set common arguments
74
+ kwargs = {}
75
+ if threshold is not None:
76
+ kwargs["threshold"] = threshold
77
+
78
+ if min_scene_len is not None:
79
+ kwargs["min_scene_len"] = min_scene_len
80
+
81
+ match detector_type:
82
+ case DetectorType.CONTENT:
83
+ if luma_only is not None:
84
+ kwargs["luma_only"] = luma_only
85
+ return ContentDetector(**kwargs)
86
+ case DetectorType.ADAPTIVE:
87
+ if adaptive_window is not None:
88
+ kwargs["window_width"] = adaptive_window
89
+ if luma_only is not None:
90
+ kwargs["luma_only"] = luma_only
91
+ if "threshold" in kwargs:
92
+ # Special case for adaptive detector which uses different param name
93
+ kwargs["adaptive_threshold"] = kwargs.pop("threshold")
94
+ return AdaptiveDetector(**kwargs)
95
+ case DetectorType.THRESHOLD:
96
+ if fade_bias is not None:
97
+ kwargs["fade_bias"] = fade_bias
98
+ return ThresholdDetector(**kwargs)
99
+ case DetectorType.HISTOGRAM:
100
+ return HistogramDetector(**kwargs)
101
+ case _:
102
+ raise ValueError(f"Unknown detector type: {detector_type}")
103
+
104
+
105
+ def validate_output_dir(output_dir: str) -> Path:
106
+ """Validate and create output directory if it doesn't exist.
107
+ Args:
108
+ output_dir: Path to the output directory
109
+ Returns:
110
+ Path object of the validated output directory
111
+ """
112
+ path = Path(output_dir)
113
+
114
+ if path.exists() and not path.is_dir():
115
+ raise typer.BadParameter(f"{output_dir} exists but is not a directory")
116
+
117
+ # Create the directory if needed, as promised in the docstring
+ path.mkdir(parents=True, exist_ok=True)
+ return path
118
+
119
+
120
+ def parse_timecode(video: any, time_str: Optional[str]) -> Optional[FrameTimecode]:
121
+ """Parse a timecode string into a FrameTimecode object.
122
+ Supports formats:
123
+ - Frames: '123'
124
+ - Seconds: '123s' or '123.45s'
125
+ - Timecode: '00:02:03' or '00:02:03.456'
126
+ Args:
127
+ video: Video object to get framerate from
128
+ time_str: String to parse, or None
129
+ Returns:
130
+ FrameTimecode object or None if input is None
131
+ """
132
+ if time_str is None:
133
+ return None
134
+
135
+ try:
136
+ if time_str.endswith("s"):
137
+ # Seconds format
138
+ seconds = float(time_str[:-1])
139
+ return FrameTimecode(timecode=seconds, fps=video.frame_rate)
140
+ elif ":" in time_str:
141
+ # Timecode format
142
+ return FrameTimecode(timecode=time_str, fps=video.frame_rate)
143
+ else:
144
+ # Frame number format
145
+ return FrameTimecode(timecode=int(time_str), fps=video.frame_rate)
146
+ except ValueError as e:
147
+ raise typer.BadParameter(
148
+ f"Invalid timecode format: {time_str}. Use frames (123), "
149
+ f"seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn])",
150
+ ) from e
151
+
152
+
153
+ def detect_and_split_scenes( # noqa: PLR0913
154
+ video_path: str,
155
+ output_dir: Path,
156
+ detector_type: DetectorType,
157
+ threshold: Optional[float] = None,
158
+ min_scene_len: Optional[int] = None,
159
+ max_scenes: Optional[int] = None,
160
+ filter_shorter_than: Optional[str] = None,
161
+ skip_start: Optional[int] = None, # noqa: ARG001
162
+ skip_end: Optional[int] = None, # noqa: ARG001
163
+ save_images_per_scene: int = 0,
164
+ stats_file: Optional[str] = None,
165
+ luma_only: bool = False,
166
+ adaptive_window: Optional[int] = None,
167
+ fade_bias: Optional[float] = None,
168
+ downscale_factor: Optional[int] = None,
169
+ frame_skip: int = 0,
170
+ duration: Optional[str] = None,
171
+ ) -> List[Tuple[FrameTimecode, FrameTimecode]]:
172
+ """Detect and split scenes in a video using the specified parameters.
173
+ Args:
174
+ video_path: Path to input video.
175
+ output_dir: Directory to save output split scenes.
176
+ detector_type: Type of scene detector to use.
177
+ threshold: Detection threshold.
178
+ min_scene_len: Minimum scene length in frames.
179
+ max_scenes: Maximum number of scenes to detect.
180
+ filter_shorter_than: Filter out scenes shorter than this duration (frames/seconds/timecode)
181
+ skip_start: Number of frames to skip at start.
182
+ skip_end: Number of frames to skip at end.
183
+ save_images_per_scene: Number of images to save per scene (0 to disable).
184
+ stats_file: Path to save detection statistics (optional).
185
+ luma_only: Only use brightness for content detection.
186
+ adaptive_window: Window size for adaptive detection.
187
+ fade_bias: Bias for fade detection (-1.0 to 1.0).
188
+ downscale_factor: Factor to downscale frames by during detection.
189
+ frame_skip: Number of frames to skip (i.e. process every 1 in N+1 frames,
190
+ where N is frame_skip, so only 1/(N+1) of the frames are processed,
191
+ speeding up the detection time at the expense of accuracy).
192
+ frame_skip must be 0 (the default) when using a StatsManager.
193
+ duration: How much of the video to process from start position.
194
+ Can be specified as frames (123), seconds (123s/123.45s),
195
+ or timecode (HH:MM:SS[.nnn]).
196
+ Returns:
197
+ List of detected scenes as (start, end) FrameTimecode pairs.
198
+ """
199
+ # Create video stream
200
+ video = open_video(video_path, backend="opencv")
201
+
202
+ # Parse duration if specified
203
+ duration_tc = parse_timecode(video, duration)
204
+
205
+ # Parse filter_shorter_than if specified
206
+ filter_shorter_than_tc = parse_timecode(video, filter_shorter_than)
207
+
208
+ # Initialize scene manager with optional stats manager
209
+ stats_manager = StatsManager() if stats_file else None
210
+ scene_manager = SceneManager(stats_manager)
211
+
212
+ # Configure scene manager
213
+ if downscale_factor:
214
+ scene_manager.auto_downscale = False
215
+ scene_manager.downscale = downscale_factor
216
+
217
+ # Create and add detector
218
+ detector = create_detector(
219
+ detector_type=detector_type,
220
+ threshold=threshold,
221
+ min_scene_len=min_scene_len,
222
+ luma_only=luma_only,
223
+ adaptive_window=adaptive_window,
224
+ fade_bias=fade_bias,
225
+ )
226
+ scene_manager.add_detector(detector)
227
+
228
+ # Detect scenes
229
+ typer.echo("Detecting scenes...")
230
+ scene_manager.detect_scenes(
231
+ video=video,
232
+ show_progress=True,
233
+ frame_skip=frame_skip,
234
+ duration=duration_tc,
235
+ )
236
+
237
+ # Get scene list
238
+ scenes = scene_manager.get_scene_list()
239
+
240
+ # Filter out scenes that are too short if filter_shorter_than is specified
241
+ if filter_shorter_than_tc:
242
+ original_count = len(scenes)
243
+ scenes = [
244
+ (start, end)
245
+ for start, end in scenes
246
+ if (end.get_frames() - start.get_frames()) >= filter_shorter_than_tc.get_frames()
247
+ ]
248
+ if len(scenes) < original_count:
249
+ typer.echo(
250
+ f"Filtered out {original_count - len(scenes)} scenes shorter "
251
+ f"than {filter_shorter_than_tc.get_seconds():.1f} seconds "
252
+ f"({filter_shorter_than_tc.get_frames()} frames)",
253
+ )
254
+
255
+ # Apply max scenes limit if specified
256
+ if max_scenes and len(scenes) > max_scenes:
257
+ typer.echo(f"Dropping last {len(scenes) - max_scenes} scenes to meet max_scenes ({max_scenes}) limit")
258
+ scenes = scenes[:max_scenes]
259
+
260
+ # Print scene information
261
+ typer.echo(f"Found {len(scenes)} scenes:")
262
+ for i, (start, end) in enumerate(scenes, 1):
263
+ typer.echo(
264
+ f"Scene {i}: {start.get_timecode()} to {end.get_timecode()} "
265
+ f"({end.get_frames() - start.get_frames()} frames)",
266
+ )
267
+
268
+ # Save stats if requested
269
+ if stats_file:
270
+ typer.echo(f"Saving detection stats to {stats_file}")
271
+ stats_manager.save_to_csv(stats_file)
272
+
273
+ # Split video into scenes
274
+ typer.echo("Splitting video into scenes...")
275
+ try:
276
+ split_video_ffmpeg(
277
+ input_video_path=video_path,
278
+ scene_list=scenes,
279
+ output_dir=output_dir,
280
+ show_progress=True,
281
+ )
282
+ typer.echo(f"Scenes have been saved to: {output_dir}")
283
+ except Exception as e:
284
+ raise typer.BadParameter(f"Error splitting video: {e}") from e
285
+
286
+ # Save preview images if requested
287
+ if save_images_per_scene > 0:
288
+ typer.echo(f"Saving {save_images_per_scene} preview images per scene...")
289
+ image_filenames = save_scene_images(
290
+ scene_list=scenes,
291
+ video=video,
292
+ num_images=save_images_per_scene,
293
+ output_dir=str(output_dir),
294
+ show_progress=True,
295
+ )
296
+
297
+ # Generate HTML report with scene information and previews
298
+ html_path = output_dir / "scene_report.html"
299
+ write_scene_list_html(
300
+ output_html_filename=str(html_path),
301
+ scene_list=scenes,
302
+ image_filenames=image_filenames,
303
+ )
304
+ typer.echo(f"Scene report saved to: {html_path}")
305
+
306
+ return scenes
307
+
308
+
309
+ @app.command()
310
+ def main( # noqa: PLR0913
311
+ video_path: Path = typer.Argument( # noqa: B008
312
+ ...,
313
+ help="Path to the input video file",
314
+ exists=True,
315
+ dir_okay=False,
316
+ ),
317
+ output_dir: str = typer.Argument(
318
+ ...,
319
+ help="Directory where split scenes will be saved",
320
+ ),
321
+ detector: DetectorType = typer.Option( # noqa: B008
322
+ DetectorType.CONTENT,
323
+ help="Scene detection algorithm to use",
324
+ ),
325
+ threshold: Optional[float] = typer.Option(
326
+ None,
327
+ help="Detection threshold (meaning varies by detector)",
328
+ ),
329
+ max_scenes: Optional[int] = typer.Option(
330
+ None,
331
+ help="Maximum number of scenes to produce",
332
+ ),
333
+ min_scene_length: Optional[int] = typer.Option(
334
+ None,
335
+ help="Minimum scene length during detection. Forces the detector to produce scenes that are at least this many frames long. "
336
+ "This affects scene detection behavior but does not filter out short scenes.",
337
+ ),
338
+ filter_shorter_than: Optional[str] = typer.Option(
339
+ None,
340
+ help="Filter out scenes shorter than this duration. Can be specified as frames (123), "
341
+ "seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn]). These scenes will be detected but not saved.",
342
+ ),
343
+ skip_start: Optional[int] = typer.Option(
344
+ None,
345
+ help="Number of frames to skip at the start of the video",
346
+ ),
347
+ skip_end: Optional[int] = typer.Option(
348
+ None,
349
+ help="Number of frames to skip at the end of the video",
350
+ ),
351
+ duration: Optional[str] = typer.Option(
352
+ None,
353
+ "-d",
354
+ help="How much of the video to process. Can be specified as frames (123), "
355
+ "seconds (123s/123.45s), or timecode (HH:MM:SS[.nnn])",
356
+ ),
357
+ save_images: int = typer.Option(
358
+ 0,
359
+ help="Number of preview images to save per scene (0 to disable)",
360
+ ),
361
+ stats_file: Optional[str] = typer.Option(
362
+ None,
363
+ help="Path to save detection statistics CSV",
364
+ ),
365
+ luma_only: bool = typer.Option(
366
+ False,
367
+ help="Only use brightness for content detection",
368
+ ),
369
+ adaptive_window: Optional[int] = typer.Option(
370
+ None,
371
+ help="Window size for adaptive detection",
372
+ ),
373
+ fade_bias: Optional[float] = typer.Option(
374
+ None,
375
+ help="Bias for fade detection (-1.0 to 1.0)",
376
+ ),
377
+ downscale: Optional[int] = typer.Option(
378
+ None,
379
+ help="Factor to downscale frames by during detection",
380
+ ),
381
+ frame_skip: int = typer.Option(
382
+ 0,
383
+ help="Number of frames to skip during processing",
384
+ ),
385
+ ) -> None:
386
+ """Split video into scenes using PySceneDetect."""
387
+ if skip_start or skip_end:
388
+ typer.echo("Skipping start and end frames is not supported yet.")
389
+ return
390
+
391
+ # Validate output directory
392
+ output_path = validate_output_dir(output_dir)
393
+
394
+ # Detect and split scenes
395
+ detect_and_split_scenes(
396
+ video_path=str(video_path),
397
+ output_dir=output_path,
398
+ detector_type=detector,
399
+ threshold=threshold,
400
+ min_scene_len=min_scene_length,
401
+ max_scenes=max_scenes,
402
+ filter_shorter_than=filter_shorter_than,
403
+ skip_start=skip_start,
404
+ skip_end=skip_end,
405
+ duration=duration,
406
+ save_images_per_scene=save_images,
407
+ stats_file=stats_file,
408
+ luma_only=luma_only,
409
+ adaptive_window=adaptive_window,
410
+ fade_bias=fade_bias,
411
+ downscale_factor=downscale,
412
+ frame_skip=frame_skip,
413
+ )
414
+
415
+
416
+ if __name__ == "__main__":
417
+ app()
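
For completeness, a hedged sketch of calling the scene-splitting logic programmatically rather than via the CLI. It only uses `detect_and_split_scenes` with the signature shown above; the import path and the input/output paths are placeholders.

# Editor's sketch (assumption: this script is importable; paths are placeholders).
from pathlib import Path
from split_scenes import DetectorType, detect_and_split_scenes  # hypothetical import path

scenes = detect_and_split_scenes(
    video_path="input.mp4",
    output_dir=Path("scenes_out"),
    detector_type=DetectorType.CONTENT,
    threshold=None,                  # None keeps ContentDetector's default
    min_scene_len=30,                # detected scenes must span at least 30 frames
    filter_shorter_than="2s",        # drop scenes shorter than 2 seconds
    save_images_per_scene=3,         # write 3 preview images per scene + HTML report
)
print(f"Split into {len(scenes)} scenes")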
packages/ltx-trainer/scripts/train.py ADDED
@@ -0,0 +1,64 @@
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+ Train LTXV models using configuration from YAML files.
5
+ This script provides a command-line interface for training LTXV models using
6
+ either LoRA fine-tuning or full model fine-tuning. It loads configuration from
7
+ a YAML file and passes it to the trainer.
8
+ Basic usage:
9
+ python scripts/train.py CONFIG_PATH [--disable-progress-bars]
10
+ For multi-GPU/FSDP training, configure and launch via Accelerate:
11
+ accelerate config
12
+ accelerate launch scripts/train.py CONFIG_PATH
13
+ """
14
+
15
+ from pathlib import Path
16
+
17
+ import typer
18
+ import yaml
19
+ from rich.console import Console
20
+
21
+ from ltx_trainer.config import LtxTrainerConfig
22
+ from ltx_trainer.trainer import LtxvTrainer
23
+
24
+ console = Console()
25
+ app = typer.Typer(
26
+ pretty_exceptions_enable=False,
27
+ no_args_is_help=True,
28
+ help="Train LTXV models using configuration from YAML files.",
29
+ )
30
+
31
+
32
+ @app.command()
33
+ def main(
34
+ config_path: str = typer.Argument(..., help="Path to YAML configuration file"),
35
+ disable_progress_bars: bool = typer.Option(
36
+ False,
37
+ "--disable-progress-bars",
38
+ help="Disable progress bars (useful for multi-process runs)",
39
+ ),
40
+ ) -> None:
41
+ """Train the model using the provided configuration file."""
42
+ # Load the configuration from the YAML file
43
+ config_path = Path(config_path)
44
+ if not config_path.exists():
45
+ typer.echo(f"Error: Configuration file {config_path} does not exist.")
46
+ raise typer.Exit(code=1)
47
+
48
+ with open(config_path, "r") as file:
49
+ config_data = yaml.safe_load(file)
50
+
51
+ # Convert the loaded data to the LtxTrainerConfig object
52
+ try:
53
+ trainer_config = LtxTrainerConfig(**config_data)
54
+ except Exception as e:
55
+ typer.echo(f"Error: Invalid configuration data: {e}")
56
+ raise typer.Exit(code=1) from e
57
+
58
+ # Initialize the training process
59
+ trainer = LtxvTrainer(trainer_config)
60
+ trainer.train(disable_progress_bars=disable_progress_bars)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ app()
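
The same flow can be expressed in a few lines of Python, which may be useful from a notebook. This sketch simply mirrors the command above; the config path is a placeholder for any trainer config YAML.

# Editor's sketch mirroring scripts/train.py (placeholder config path).
import yaml

from ltx_trainer.config import LtxTrainerConfig
from ltx_trainer.trainer import LtxvTrainer

with open("configs/ltx2_av_lora.yaml") as f:   # placeholder: any trainer config YAML
    config = LtxTrainerConfig(**yaml.safe_load(f))

LtxvTrainer(config).train(disable_progress_bars=False)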
packages/ltx-trainer/src/ltx_trainer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.54 kB).
 
packages/ltx-trainer/src/ltx_trainer/__pycache__/model_loader.cpython-312.pyc ADDED
Binary file (13.9 kB).
 
packages/ltx-trainer/src/ltx_trainer/captioning.py ADDED
@@ -0,0 +1,401 @@
1
+ """
2
+ Audio-visual media captioning using multimodal models.
3
+ This module provides captioning capabilities for videos with audio using:
4
+ - Qwen2.5-Omni: Local model supporting text, audio, image, and video inputs (default)
5
+ - Gemini Flash: Cloud-based API for audio-visual captioning
6
+ Requirements:
7
+ - Qwen2.5-Omni: transformers>=4.50, torch
8
+ - Gemini Flash: google-generativeai (uv pip install google-generativeai)
9
+ Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable
10
+ """
11
+
12
+ import itertools
13
+ import re
14
+ from abc import ABC, abstractmethod
15
+ from enum import Enum
16
+ from pathlib import Path
17
+
18
+ import torch
19
+
20
+ # Instruction for audio-visual captioning (default) - includes speech transcription and sounds
21
+ DEFAULT_CAPTION_INSTRUCTION = """\
22
+ Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections:
23
+
24
+ [VISUAL]: <Detailed description of people, objects, actions, settings, colors, and movements>
25
+ [SPEECH]: <Word-for-word transcription of everything spoken.
26
+ Listen carefully and transcribe the exact words. If no speech, write "None">
27
+ [SOUNDS]: <Description of music, ambient sounds, sound effects. If none, write "None">
28
+ [TEXT]: <Any on-screen text visible. If none, write "None">
29
+
30
+ You MUST fill in all four sections. For [SPEECH], transcribe the actual words spoken, not a summary."""
31
+
32
+ # Instruction for video-only captioning (no audio processing)
33
+ VIDEO_ONLY_CAPTION_INSTRUCTION = """\
34
+ Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections:
35
+
36
+ [VISUAL]: <Detailed description of people, objects, actions, settings, colors, and movements>
37
+ [TEXT]: <Any on-screen text visible. If none, write "None">
38
+
39
+ You MUST fill in both sections."""
40
+
41
+
42
+ class CaptionerType(str, Enum):
43
+ """Enum for different types of media captioners."""
44
+
45
+ QWEN_OMNI = "qwen_omni" # Local Qwen2.5-Omni model (audio + video)
46
+ GEMINI_FLASH = "gemini_flash" # Gemini Flash API (audio + video)
47
+
48
+
49
+ def create_captioner(captioner_type: CaptionerType, **kwargs) -> "MediaCaptioningModel":
50
+ """Factory function to create a media captioner.
51
+ Args:
52
+ captioner_type: The type of captioner to create
53
+ **kwargs: Additional arguments to pass to the captioner constructor
54
+ Returns:
55
+ An instance of a MediaCaptioningModel
56
+ """
57
+ match captioner_type:
58
+ case CaptionerType.QWEN_OMNI:
59
+ return QwenOmniCaptioner(**kwargs)
60
+ case CaptionerType.GEMINI_FLASH:
61
+ return GeminiFlashCaptioner(**kwargs)
62
+ case _:
63
+ raise ValueError(f"Unsupported captioner type: {captioner_type}")
64
+
65
+
66
+ class MediaCaptioningModel(ABC):
67
+ """Abstract base class for audio-visual media captioning models."""
68
+
69
+ @abstractmethod
70
+ def caption(self, path: str | Path, **kwargs) -> str:
71
+ """Generate a caption for the given video or image.
72
+ Args:
73
+ path: Path to the video/image file to caption
74
+ Returns:
75
+ A string containing the generated caption
76
+ """
77
+
78
+ @property
79
+ @abstractmethod
80
+ def supports_audio(self) -> bool:
81
+ """Whether this captioner supports audio input."""
82
+
83
+ @staticmethod
84
+ def _is_image_file(path: str | Path) -> bool:
85
+ """Check if the file is an image based on extension."""
86
+ return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".heic", ".heif", ".webp"))
87
+
88
+ @staticmethod
89
+ def _is_video_file(path: str | Path) -> bool:
90
+ """Check if the file is a video based on extension."""
91
+ return str(path).lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))
92
+
93
+ @staticmethod
94
+ def _clean_raw_caption(caption: str) -> str:
95
+ """Clean up the raw caption by removing common VLM patterns."""
96
+ start = ["The", "This"]
97
+ kind = ["video", "image", "scene", "animated sequence", "clip", "footage"]
98
+ act = ["displays", "shows", "features", "depicts", "presents", "showcases", "captures", "contains"]
99
+
100
+ for x, y, z in itertools.product(start, kind, act):
101
+ caption = caption.replace(f"{x} {y} {z} ", "", 1)
102
+
103
+ return caption
104
+
105
+
106
+ class QwenOmniCaptioner(MediaCaptioningModel):
107
+ """Audio-visual captioning using Alibaba's Qwen2.5-Omni model.
108
+ Qwen2.5-Omni is an end-to-end multimodal model that can perceive text, images, audio, and video.
109
+ It uses a Thinker-Talker architecture where the Thinker generates text and the Talker can
110
+ generate speech. For captioning, we use only the Thinker component for text generation.
111
+ Key features:
112
+ - Block-wise processing for streaming multimodal inputs
113
+ - TMRoPE (Time-aligned Multimodal RoPE) for synchronizing video and audio timestamps
114
+ - Can extract and process audio directly from video files
115
+ See: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni
116
+ Model: Qwen/Qwen2.5-Omni-7B (7B parameters)
117
+ """
118
+
119
+ MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
120
+
121
+ # Default system prompt required by Qwen2.5-Omni for proper audio processing
122
+ DEFAULT_SYSTEM_PROMPT = (
123
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
124
+ "capable of perceiving auditory and visual inputs, as well as generating text and speech."
125
+ )
126
+
127
+ def __init__(
128
+ self,
129
+ device: str | torch.device | None = None,
130
+ use_8bit: bool = False,
131
+ instruction: str | None = None,
132
+ ):
133
+ """
134
+ Initialize the Qwen2.5-Omni captioner.
135
+ Args:
136
+ device: Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu')
137
+ use_8bit: Whether to use 8-bit quantization for reduced memory usage
138
+ instruction: Custom instruction prompt. If None, uses the default instruction
139
+ """
140
+ self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
141
+ self.instruction = instruction
142
+ self._load_model(use_8bit=use_8bit)
143
+
144
+ @property
145
+ def supports_audio(self) -> bool:
146
+ return True
147
+
148
+ def caption(
149
+ self,
150
+ path: str | Path,
151
+ fps: int = 1,
152
+ include_audio: bool = True,
153
+ clean_caption: bool = True,
154
+ ) -> str:
155
+ """Generate a caption for the given video or image.
156
+ Args:
157
+ path: Path to the video/image file to caption
158
+ fps: Frames per second to sample from videos
159
+ include_audio: Whether to include audio in the captioning (for videos)
160
+ clean_caption: Whether to clean up the raw caption by removing common VLM patterns
161
+ Returns:
162
+ A string containing the generated caption
163
+ """
164
+ path = Path(path)
165
+ is_image = self._is_image_file(path)
166
+ is_video = self._is_video_file(path)
167
+
168
+ # Determine if we should process audio
169
+ use_audio = include_audio and is_video
170
+
171
+ # Use custom instruction if provided, otherwise pick appropriate default
172
+ if self.instruction is not None:
173
+ instruction = self.instruction
174
+ else:
175
+ instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION
176
+
177
+ # Build the user content based on media type
178
+ # Based on HuggingFace docs: https://huggingface.co/docs/transformers/en/model_doc/qwen2_5_omni
179
+ user_content = []
180
+
181
+ if is_image:
182
+ user_content.append({"type": "image", "image": str(path)})
183
+ elif is_video:
184
+ user_content.append({"type": "video", "video": str(path)})
185
+
186
+ # Add the instruction text
187
+ user_content.append({"type": "text", "text": instruction})
188
+
189
+ # Build conversation - use the default system prompt required by Qwen2.5-Omni
190
+ # Using a custom system prompt causes warnings and may affect audio processing
191
+ messages = [
192
+ {
193
+ "role": "system",
194
+ "content": [{"type": "text", "text": self.DEFAULT_SYSTEM_PROMPT}],
195
+ },
196
+ {"role": "user", "content": user_content},
197
+ ]
198
+
199
+ # Process inputs using the processor's apply_chat_template
200
+ # For videos with audio, use load_audio_from_video=True and use_audio_in_video=True
201
+ inputs = self.processor.apply_chat_template(
202
+ messages,
203
+ load_audio_from_video=use_audio,
204
+ add_generation_prompt=True,
205
+ tokenize=True,
206
+ return_dict=True,
207
+ return_tensors="pt",
208
+ fps=fps,
209
+ padding=True,
210
+ use_audio_in_video=use_audio,
211
+ ).to(self.model.device)
212
+
213
+ # Generate caption (text only, using Thinker-only model)
214
+ # Note: For Qwen2_5OmniThinkerForConditionalGeneration, use standard generate params
215
+ # (not thinker_ prefixed ones, those are for the full Qwen2_5OmniForConditionalGeneration)
216
+ input_len = inputs["input_ids"].shape[1]
217
+
218
+ output_tokens = self.model.generate(
219
+ **inputs,
220
+ use_audio_in_video=use_audio,
221
+ do_sample=False,
222
+ max_new_tokens=1024,
223
+ )
224
+
225
+ # Extract only the generated tokens (exclude the input/prompt tokens)
226
+ generated_tokens = output_tokens[:, input_len:]
227
+
228
+ # Decode only the generated response
229
+ caption_raw = self.processor.batch_decode(
230
+ generated_tokens,
231
+ skip_special_tokens=True,
232
+ clean_up_tokenization_spaces=False,
233
+ )[0]
234
+
235
+ # Remove hallucinated conversation turns (e.g., "Human\nHuman\n..." or "Human: ...")
236
+ # This is a known issue with chat models continuing to generate fake turns
237
+ # We look for patterns that are clearly hallucinated chat turns, not legitimate uses of "human"
238
+
239
+ # Match "\nHuman" followed by ":", "\n", or end of string (chat turn patterns)
240
+ # This won't match "A human walks..." or "...the human body..."
241
+ caption_raw = re.split(r"\nHuman(?::|(?:\s*\n)|$)", caption_raw, maxsplit=1)[0]
242
+ caption_raw = caption_raw.strip()
243
+
244
+ # Clean up caption if requested
245
+ return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw
246
+
247
+ def _load_model(self, use_8bit: bool) -> None:
248
+ """Load the Qwen2.5-Omni model and processor.
249
+ Uses the Thinker-only model (Qwen2_5OmniThinkerForConditionalGeneration) for text generation
250
+ to save compute by not loading the audio generation components.
251
+ """
252
+ from transformers import ( # noqa: PLC0415
253
+ BitsAndBytesConfig,
254
+ Qwen2_5OmniProcessor,
255
+ Qwen2_5OmniThinkerForConditionalGeneration,
256
+ )
257
+
258
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True) if use_8bit else None
259
+
260
+ # Use Thinker-only model for text generation (saves memory by not loading Talker)
261
+ self.model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
262
+ self.MODEL_ID,
263
+ dtype=torch.bfloat16,
264
+ low_cpu_mem_usage=True,
265
+ quantization_config=quantization_config,
266
+ device_map="auto",
267
+ )
268
+
269
+ self.processor = Qwen2_5OmniProcessor.from_pretrained(self.MODEL_ID)
270
+
271
+
272
+ class GeminiFlashCaptioner(MediaCaptioningModel):
273
+ """Audio-visual captioning using Google's Gemini Flash API.
274
+ Gemini Flash is a cloud-based multimodal model that natively supports
275
+ audio and video understanding. Requires a Google API key.
276
+ Note: This captioner requires the `google-generativeai` package and a valid API key.
277
+ Set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable, or pass the key directly.
278
+ """
279
+
280
+ MODEL_ID = "gemini-flash-lite-latest"
281
+
282
+ def __init__(
283
+ self,
284
+ api_key: str | None = None,
285
+ instruction: str | None = None,
286
+ ):
287
+ """Initialize the Gemini Flash captioner.
288
+ Args:
289
+ api_key: Google API key. If not provided, will look for
290
+ GEMINI_API_KEY or GOOGLE_API_KEY environment variable.
291
+ instruction: Custom instruction prompt. If None, uses the default instruction
292
+ """
293
+ self.instruction = instruction
294
+ self._init_client(api_key)
295
+
296
+ @property
297
+ def supports_audio(self) -> bool:
298
+ return True
299
+
300
+ def caption(
301
+ self,
302
+ path: str | Path,
303
+ fps: int = 3, # noqa: ARG002 - kept for API compatibility
304
+ include_audio: bool = True,
305
+ clean_caption: bool = True,
306
+ ) -> str:
307
+ """Generate a caption for the given video or image.
308
+ Args:
309
+ path: Path to the video/image file to caption
310
+ fps: Frames per second (not used for Gemini, kept for API compatibility)
311
+ include_audio: Whether to include audio content in the caption
312
+ clean_caption: Whether to clean up the raw caption
313
+ Returns:
314
+ A string containing the generated caption
315
+ """
316
+ import time # noqa: PLC0415
317
+
318
+ path = Path(path)
319
+ is_video = self._is_video_file(path)
320
+ use_audio = include_audio and is_video
321
+
322
+ # Use custom instruction if provided, otherwise pick appropriate default
323
+ if self.instruction is not None:
324
+ instruction = self.instruction
325
+ else:
326
+ instruction = DEFAULT_CAPTION_INSTRUCTION if use_audio else VIDEO_ONLY_CAPTION_INSTRUCTION
327
+
328
+ # Upload the file to Gemini
329
+ uploaded_file = self._genai.upload_file(path)
330
+
331
+ # Wait for processing to complete (videos need time to process)
332
+ while uploaded_file.state.name == "PROCESSING":
333
+ time.sleep(1)
334
+ uploaded_file = self._genai.get_file(uploaded_file.name)
335
+
336
+ if uploaded_file.state.name == "FAILED":
337
+ raise RuntimeError(f"File processing failed: {uploaded_file.state.name}")
338
+
339
+ # Generate caption
340
+ response = self._model.generate_content([uploaded_file, instruction])
341
+
342
+ caption_raw = response.text
343
+
344
+ # Clean up the uploaded file
345
+ self._genai.delete_file(uploaded_file.name)
346
+
347
+ # Clean up caption if requested
348
+ return self._clean_raw_caption(caption_raw) if clean_caption else caption_raw
349
+
350
+ def _init_client(self, api_key: str | None) -> None:
351
+ """Initialize the Gemini API client."""
352
+ import os # noqa: PLC0415
353
+
354
+ try:
355
+ import google.generativeai as genai # noqa: PLC0415
356
+ except ImportError as e:
357
+ raise ImportError(
358
+ "The `google-generativeai` package is required for Gemini Flash captioning. "
359
+ "Install it with: `uv pip install google-generativeai`"
360
+ ) from e
361
+
362
+ # Get API key from argument or environment
363
+ # GEMINI_API_KEY is the recommended variable, GOOGLE_API_KEY also works
364
+ resolved_api_key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
365
+
366
+ if not resolved_api_key:
367
+ raise ValueError(
368
+ "Gemini API key is required. Provide it via the `api_key` argument "
369
+ "or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable."
370
+ )
371
+
372
+ # Configure the genai library with the API key
373
+ genai.configure(api_key=resolved_api_key)
374
+
375
+ # Store reference to genai module for file operations
376
+ self._genai = genai
377
+
378
+ # Initialize the model
379
+ self._model = genai.GenerativeModel(self.MODEL_ID)
380
+
381
+
382
+ def example() -> None:
383
+ """Example usage of the captioning module."""
384
+ import sys # noqa: PLC0415
385
+
386
+ if len(sys.argv) < 2:
387
+ print(f"Usage: python {sys.argv[0]} <video_path> [captioner_type]") # noqa: T201
388
+ print(" captioner_type: qwen_omni (default) or gemini_flash") # noqa: T201
389
+ sys.exit(1)
390
+
391
+ video_path = sys.argv[1]
392
+ captioner_type = CaptionerType(sys.argv[2]) if len(sys.argv) > 2 else CaptionerType.QWEN_OMNI
393
+
394
+ print(f"Using {captioner_type.value} captioner:") # noqa: T201
395
+ captioner = create_captioner(captioner_type)
396
+ caption = captioner.caption(video_path)
397
+ print(f"CAPTION: {caption}") # noqa: T201
398
+
399
+
400
+ if __name__ == "__main__":
401
+ example()
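
As a small usage sketch beyond the built-in example(), the snippet below exercises both captioner back-ends with the constructor arguments defined above. The media path is a placeholder, and the Gemini path assumes a valid GEMINI_API_KEY (or GOOGLE_API_KEY) is set.

# Editor's sketch (placeholder media path; GEMINI_API_KEY assumed to be set).
from ltx_trainer.captioning import CaptionerType, create_captioner

video = "clip.mp4"  # placeholder

# Local Qwen2.5-Omni with 8-bit weights to reduce VRAM usage.
qwen = create_captioner(CaptionerType.QWEN_OMNI, use_8bit=True)
print(qwen.caption(video, fps=1, include_audio=True))

# Cloud-based Gemini Flash with a custom instruction prompt.
gemini = create_captioner(
    CaptionerType.GEMINI_FLASH,
    instruction="Describe the clip in one sentence.",
)
print(gemini.caption(video, include_audio=True))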
packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py ADDED
@@ -0,0 +1,85 @@
1
+ # ruff: noqa: PLC0415
2
+
3
+ """
4
+ 8-bit Gemma text encoder loading utilities.
5
+ This module provides functionality for loading the Gemma text encoder in 8-bit precision
6
+ using bitsandbytes, which significantly reduces GPU memory usage.
7
+ Example usage:
8
+ from ltx_trainer.gemma_8bit import load_8bit_gemma
9
+ text_encoder = load_8bit_gemma(gemma_model_path="/path/to/gemma")
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ from collections.abc import Generator
16
+ from contextlib import contextmanager
17
+ from pathlib import Path
18
+
19
+ import torch
20
+
21
+ from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoder
22
+ from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
23
+
24
+
25
+ def load_8bit_gemma(gemma_model_path: str | Path, dtype: torch.dtype = torch.bfloat16) -> GemmaTextEncoder:
26
+ """Load the Gemma text encoder in 8-bit precision using bitsandbytes.
27
+ Only the Gemma LLM backbone is loaded here. The embeddings processor
28
+ (feature extractor + connectors) should be loaded separately via
29
+ :func:`ltx_trainer.model_loader.load_embeddings_processor`.
30
+ Args:
31
+ gemma_model_path: Path to Gemma model directory
32
+ dtype: Data type for non-quantized model weights
33
+ Returns:
34
+ GemmaTextEncoder with 8-bit quantized Gemma backbone
35
+ Raises:
36
+ ImportError: If bitsandbytes is not installed
37
+ FileNotFoundError: If required model files are not found
38
+ """
39
+ try:
40
+ from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration
41
+ except ImportError as e:
42
+ raise ImportError(
43
+ "8-bit text encoder loading requires bitsandbytes. Install it with: uv pip install bitsandbytes"
44
+ ) from e
45
+
46
+ gemma_path = _find_gemma_subpath(gemma_model_path, "model*.safetensors")
47
+ tokenizer_path = _find_gemma_subpath(gemma_model_path, "tokenizer.model")
48
+
49
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
50
+ with _suppress_accelerate_memory_warnings():
51
+ gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
52
+ gemma_path,
53
+ quantization_config=quantization_config,
54
+ torch_dtype=torch.bfloat16,
55
+ device_map="auto",
56
+ local_files_only=True,
57
+ )
58
+
59
+ tokenizer = LTXVGemmaTokenizer(tokenizer_path, 1024)
60
+
61
+ return GemmaTextEncoder(
62
+ tokenizer=tokenizer,
63
+ model=gemma_model,
64
+ dtype=dtype,
65
+ )
66
+
67
+
68
+ def _find_gemma_subpath(root_path: str | Path, pattern: str) -> str:
69
+ """Find a file matching a glob pattern and return its parent directory."""
70
+ matches = list(Path(root_path).rglob(pattern))
71
+ if not matches:
72
+ raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}")
73
+ return str(matches[0].parent)
74
+
75
+
76
+ @contextmanager
77
+ def _suppress_accelerate_memory_warnings() -> Generator[None, None, None]:
78
+ """Temporarily suppress INFO warnings from accelerate about memory allocation."""
79
+ accelerate_logger = logging.getLogger("accelerate.utils.modeling")
80
+ old_level = accelerate_logger.level
81
+ accelerate_logger.setLevel(logging.WARNING)
82
+ try:
83
+ yield
84
+ finally:
85
+ accelerate_logger.setLevel(old_level)
packages/ltx-trainer/src/ltx_trainer/gpu_utils.py ADDED
@@ -0,0 +1,90 @@
1
+ """GPU memory management utilities for training and inference."""
2
+
3
+ import functools
4
+ import gc
5
+ import subprocess
6
+ from typing import Callable, TypeVar
7
+
8
+ import torch
9
+
10
+ from ltx_trainer import logger
11
+
12
+ F = TypeVar("F", bound=Callable)
13
+
14
+
15
+ def free_gpu_memory(log: bool = False) -> None:
16
+ """Free GPU memory by running garbage collection and emptying CUDA cache.
17
+ Args:
18
+ log: If True, log memory stats after clearing
19
+ """
20
+ gc.collect()
21
+ if torch.cuda.is_available():
22
+ torch.cuda.empty_cache()
23
+ if log:
24
+ allocated = torch.cuda.memory_allocated() / 1024**3
25
+ reserved = torch.cuda.memory_reserved() / 1024**3
26
+ logger.debug(f"GPU memory freed. Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
27
+
28
+
29
+ class free_gpu_memory_context: # noqa: N801
30
+ """Context manager and decorator to free GPU memory before and/or after execution.
31
+ Can be used as a decorator:
32
+ @free_gpu_memory_context(after=True)
33
+ def my_function():
34
+ ...
35
+ Or as a context manager:
36
+ with free_gpu_memory_context():
37
+ heavy_operation()
38
+ Args:
39
+ before: Free memory before execution (default: False)
40
+ after: Free memory after execution (default: True)
41
+ log: Log memory stats when freeing (default: False)
42
+ """
43
+
44
+ def __init__(self, *, before: bool = False, after: bool = True, log: bool = False) -> None:
45
+ self.before = before
46
+ self.after = after
47
+ self.log = log
48
+
49
+ def __enter__(self) -> "free_gpu_memory_context":
50
+ if self.before:
51
+ free_gpu_memory(log=self.log)
52
+ return self
53
+
54
+ def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: object) -> None:
55
+ if self.after:
56
+ free_gpu_memory(log=self.log)
57
+
58
+ def __call__(self, func: F) -> F:
59
+ @functools.wraps(func)
60
+ def wrapper(*args, **kwargs) -> object:
61
+ with self:
62
+ return func(*args, **kwargs)
63
+
64
+ return wrapper # type: ignore
65
+
66
+
67
+ def get_gpu_memory_gb(device: torch.device) -> float:
68
+ """Get current GPU memory usage in GB using nvidia-smi.
69
+ Args:
70
+ device: torch.device to get memory usage for
71
+ Returns:
72
+ Current GPU memory usage in GB
73
+ """
74
+ try:
75
+ device_id = device.index if device.index is not None else 0
76
+ result = subprocess.check_output(
77
+ [
78
+ "nvidia-smi",
79
+ "--query-gpu=memory.used",
80
+ "--format=csv,nounits,noheader",
81
+ "-i",
82
+ str(device_id),
83
+ ],
84
+ encoding="utf-8",
85
+ )
86
+ return float(result.strip()) / 1024 # Convert MB to GB
87
+ except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
88
+ logger.error(f"Failed to get GPU memory from nvidia-smi: {e}")
89
+ # Fallback to torch
90
+ return torch.cuda.memory_allocated(device) / 1024**3
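
A short sketch combining the two utilities above: wrap a memory-heavy call with `free_gpu_memory_context` and report usage via `get_gpu_memory_gb`. The wrapped function is a placeholder and assumes a CUDA device is available.

# Editor's sketch using the helpers defined above (requires a CUDA device).
import torch

from ltx_trainer.gpu_utils import free_gpu_memory_context, get_gpu_memory_gb

device = torch.device("cuda:0")

@free_gpu_memory_context(before=True, after=True)
def encode_batch(batch: torch.Tensor) -> torch.Tensor:
    # Placeholder for a VRAM-heavy operation (e.g., VAE encoding).
    return batch.to(device).float().mean(dim=0)

_ = encode_batch(torch.randn(8, 3, 256, 256))
print(f"GPU memory in use: {get_gpu_memory_gb(device):.2f} GB")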
packages/ltx-trainer/src/ltx_trainer/progress.py ADDED
@@ -0,0 +1,236 @@
1
+ """Progress tracking for LTX training.
2
+ This module provides a unified progress display for training and validation sampling,
3
+ encapsulating all Rich progress bar logic in one place.
4
+ """
5
+
6
+ from rich.progress import (
7
+ BarColumn,
8
+ Progress,
9
+ TaskID,
10
+ TextColumn,
11
+ TimeElapsedColumn,
12
+ TimeRemainingColumn,
13
+ )
14
+
15
+
16
+ class SamplingContext:
17
+ """Context for validation sampling progress tracking.
18
+ Provides a unified progress display showing current video and denoising step.
19
+ Display format: "Sampling X/Y [████████████] step Z/W"
20
+ The progress bar shows the denoising progress for the current video.
21
+ """
22
+
23
+ def __init__(self, progress: Progress | None, task: TaskID | None, num_prompts: int, num_steps: int):
24
+ self._progress = progress
25
+ self._task = task
26
+ self._num_prompts = num_prompts
27
+ self._num_steps = num_steps
28
+
29
+ def start_video(self, video_idx: int) -> None:
30
+ """Start tracking a new video (resets step progress)."""
31
+ if self._progress is None or self._task is None:
32
+ return
33
+ # Reset task for new video: completed=0, total=num_steps
34
+ self._progress.reset(self._task, total=self._num_steps)
35
+ self._progress.update(
36
+ self._task,
37
+ completed=0,
38
+ video=f"{video_idx + 1}/{self._num_prompts}",
39
+ info=f"step 0/{self._num_steps}",
40
+ )
41
+
42
+ def advance_step(self) -> None:
43
+ """Advance the denoising step by one."""
44
+ if self._progress is None or self._task is None:
45
+ return
46
+ self._progress.advance(self._task)
47
+ completed = int(self._progress.tasks[self._task].completed)
48
+ self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")
49
+
50
+ def cleanup(self) -> None:
51
+ """Hide sampling task when done."""
52
+ if self._progress is None or self._task is None:
53
+ return
54
+ self._progress.update(self._task, visible=False)
55
+
56
+
57
+ class StandaloneSamplingProgress:
58
+ """Standalone progress display for inference scripts.
59
+ Unlike SamplingContext (which integrates with TrainingProgress), this class
60
+ manages its own Rich Progress instance for use in standalone inference scripts.
61
+ Usage:
62
+ with StandaloneSamplingProgress(num_steps=30) as ctx:
63
+ for step in range(30):
64
+ # ... denoising step ...
65
+ ctx.advance_step()
66
+ """
67
+
68
+ def __init__(self, num_steps: int, description: str = "Generating"):
69
+ """Initialize standalone sampling progress.
70
+ Args:
71
+ num_steps: Total number of denoising steps
72
+ description: Description to show in progress bar
73
+ """
74
+ self._num_steps = num_steps
75
+ self._description = description
76
+ self._progress: Progress | None = None
77
+ self._task: TaskID | None = None
78
+
79
+ def __enter__(self) -> "StandaloneSamplingProgress":
80
+ """Start the progress display."""
81
+ self._progress = Progress(
82
+ TextColumn("[progress.description]{task.description}"),
83
+ BarColumn(bar_width=40, style="blue"),
84
+ TextColumn("{task.fields[info]}", style="cyan"),
85
+ TimeElapsedColumn(),
86
+ TextColumn("ETA:"),
87
+ TimeRemainingColumn(compact=True),
88
+ )
89
+ self._progress.__enter__()
90
+ self._task = self._progress.add_task(
91
+ self._description,
92
+ total=self._num_steps,
93
+ info=f"step 0/{self._num_steps}",
94
+ )
95
+ return self
96
+
97
+ def __exit__(self, *args) -> None:
98
+ """Stop the progress display."""
99
+ if self._progress is not None:
100
+ self._progress.__exit__(*args)
101
+
102
+ def advance_step(self) -> None:
103
+ """Advance the denoising step by one."""
104
+ if self._progress is None or self._task is None:
105
+ return
106
+ self._progress.advance(self._task)
107
+ completed = int(self._progress.tasks[self._task].completed)
108
+ self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")
109
+
110
+
111
+ class TrainingProgress:
112
+ """Manages Rich progress display for training and validation.
113
+ This class encapsulates all progress bar logic, providing a clean interface
114
+ for the trainer to update progress without dealing with Rich internals.
115
+ Usage:
116
+ with TrainingProgress(enabled=True, total_steps=1000) as progress:
117
+ for step in range(1000):
118
+ # ... training step ...
119
+ progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)
120
+ if should_validate:
121
+ sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
122
+ sampler = ValidationSampler(..., sampling_context=sampling_ctx)
123
+ for prompt_idx, prompt in enumerate(prompts):
124
+ sampling_ctx.start_video(prompt_idx)
125
+ sampler.generate(...)
126
+ sampling_ctx.cleanup()
127
+ """
128
+
129
+ def __init__(self, enabled: bool, total_steps: int):
130
+ """Initialize progress tracking.
131
+ Args:
132
+ enabled: Whether to display progress bars (False for non-main processes)
133
+ total_steps: Total number of training steps
134
+ """
135
+ self._enabled = enabled
136
+ self._total_steps = total_steps
137
+ self._train_task: TaskID | None = None
138
+
139
+ if not enabled:
140
+ self._progress = None
141
+ return
142
+
143
+ # Single Progress instance with flexible columns
144
+ self._progress = Progress(
145
+ TextColumn("[progress.description]{task.description}"),
146
+ TextColumn("{task.fields[video]}", style="magenta"),
147
+ BarColumn(bar_width=40, style="blue"),
148
+ TextColumn("{task.fields[info]}", style="cyan"),
149
+ TimeElapsedColumn(),
150
+ TextColumn("ETA:"),
151
+ TimeRemainingColumn(compact=True),
152
+ )
153
+
154
+ def __enter__(self) -> "TrainingProgress":
155
+ """Enter the progress context, starting the live display."""
156
+ if self._progress is not None:
157
+ self._progress.__enter__()
158
+ self._train_task = self._progress.add_task(
159
+ "Training",
160
+ total=self._total_steps,
161
+ video=f"0/{self._total_steps}",
162
+ info="Starting...",
163
+ )
164
+ return self
165
+
166
+ def __exit__(self, *args) -> None:
167
+ """Exit the progress context, stopping the live display."""
168
+ if self._progress is not None:
169
+ self._progress.__exit__(*args)
170
+
171
+ @property
172
+ def enabled(self) -> bool:
173
+ """Whether progress display is enabled."""
174
+ return self._enabled
175
+
176
+ def update_training(
177
+ self,
178
+ *,
179
+ loss: float,
180
+ lr: float,
181
+ step_time: float,
182
+ advance: bool = True,
183
+ ) -> None:
184
+ """Update the training progress display.
185
+ Args:
186
+ loss: Current training loss
187
+ lr: Current learning rate
188
+ step_time: Time taken for this step in seconds
189
+ advance: Whether to advance the progress by one step
190
+ """
191
+ if self._progress is None or self._train_task is None:
192
+ return
193
+
194
+ info = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step"
195
+ self._progress.update(
196
+ self._train_task,
197
+ advance=1 if advance else 0,
198
+ info=info,
199
+ )
200
+ # Update step count in video column
201
+ completed = int(self._progress.tasks[self._train_task].completed)
202
+ self._progress.update(self._train_task, video=f"{completed}/{self._total_steps}")
203
+
204
+ def start_sampling(self, num_prompts: int, num_steps: int) -> SamplingContext:
205
+ """Start validation sampling progress tracking.
206
+ Creates a task that shows current video and denoising step progress.
207
+ Format: "Sampling X/Y [████████████] step Z/W"
208
+ Args:
209
+ num_prompts: Number of validation prompts to sample
210
+ num_steps: Number of denoising steps per sample
211
+ Returns:
212
+ SamplingContext for tracking progress (no-op if progress is disabled)
213
+ """
214
+ if self._progress is None:
215
+ # Return a no-op context when progress is disabled
216
+ return SamplingContext(
217
+ progress=None,
218
+ task=None,
219
+ num_prompts=num_prompts,
220
+ num_steps=num_steps,
221
+ )
222
+
223
+ task = self._progress.add_task(
224
+ "Sampling",
225
+ total=num_steps,
226
+ completed=0,
227
+ video=f"0/{num_prompts}",
228
+ info=f"step 0/{num_steps}",
229
+ )
230
+
231
+ return SamplingContext(
232
+ progress=self._progress,
233
+ task=task,
234
+ num_prompts=num_prompts,
235
+ num_steps=num_steps,
236
+ )
packages/ltx-trainer/src/ltx_trainer/quantization.py ADDED
@@ -0,0 +1,195 @@
1
+ # Adapted from: https://github.com/bghira/SimpleTuner
2
+ # With improvements from: https://github.com/ostris/ai-toolkit
3
+ from typing import Literal
4
+
5
+ import torch
6
+ from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn
7
+
8
+ from ltx_trainer import logger
9
+
10
+ QuantizationOptions = Literal[
11
+ "int8-quanto",
12
+ "int4-quanto",
13
+ "int2-quanto",
14
+ "fp8-quanto",
15
+ "fp8uz-quanto",
16
+ ]
17
+
18
+ # Modules to exclude from quantization.
19
+ # These are glob patterns passed to quanto's `exclude` parameter.
20
+ # When quantizing the full model at once, these patterns match against full module paths.
21
+ # When quantizing block-by-block, we also use SKIP_ROOT_MODULES for top-level modules.
22
+ EXCLUDE_PATTERNS = [
23
+ # Input/output projection layers
24
+ "patchify_proj",
25
+ "audio_patchify_proj",
26
+ "proj_out",
27
+ "audio_proj_out",
28
+ # Timestep embedding layers - int4 tinygemm requires strict bfloat16 input
29
+ # and these receive float32 sinusoidal embeddings that are cast to bfloat16
30
+ "*adaln*",
31
+ "time_proj",
32
+ "timestep_embedder*",
33
+ # Caption/text projection layers
34
+ "caption_projection*",
35
+ "audio_caption_projection*",
36
+ # Normalization layers (usually excluded from quantization)
37
+ "*norm*",
38
+ ]
39
+
40
+ # Top-level modules to skip entirely during block-by-block quantization.
41
+ # These are exact matches against model.named_children() names.
42
+ # (Needed because quanto's exclude patterns don't work when calling quantize() directly on a module)
43
+ SKIP_ROOT_MODULES = {
44
+ "patchify_proj",
45
+ "audio_patchify_proj",
46
+ "proj_out",
47
+ "audio_proj_out",
48
+ "audio_caption_projection",
49
+ }
50
+
51
+
52
+ def quantize_model(
53
+ model: torch.nn.Module,
54
+ precision: QuantizationOptions,
55
+ quantize_activations: bool = False,
56
+ device: torch.device | str | None = None,
57
+ ) -> torch.nn.Module:
58
+ """
59
+ Quantize a model using optimum-quanto.
60
+ For large models with transformer_blocks, this function quantizes block-by-block
61
+ on GPU then moves back to CPU, which is much faster than quantizing on CPU and
62
+ uses less peak VRAM than loading the entire model to GPU at once.
63
+ Args:
64
+ model: The model to quantize.
65
+ precision: The quantization precision (e.g. "int8-quanto", "fp8-quanto").
66
+ quantize_activations: Whether to quantize activations in addition to weights.
67
+ device: Device to use for quantization. If None, uses CUDA if available, else CPU.
68
+ Returns:
69
+ The quantized model.
70
+ """
71
+ from optimum.quanto import freeze, quantize # noqa: PLC0415
72
+
73
+ if device is None:
74
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75
+ elif isinstance(device, str):
76
+ device = torch.device(device)
77
+
78
+ weight_quant = _get_quanto_dtype(precision)
79
+
80
+ if quantize_activations:
81
+ logger.debug("Quantizing model weights and activations")
82
+ activations_quant = weight_quant
83
+ else:
84
+ activations_quant = None
85
+
86
+ # Remember original device to restore after quantization
87
+ original_device = next(model.parameters()).device
88
+
89
+ # Check if model has transformer_blocks for block-by-block quantization
90
+ if hasattr(model, "transformer_blocks"):
91
+ logger.debug("Quantizing model using block-by-block approach for memory efficiency")
92
+ _quantize_blockwise(
93
+ model,
94
+ weight_quant=weight_quant,
95
+ activations_quant=activations_quant,
96
+ device=device,
97
+ )
98
+ else:
99
+ # Fallback: quantize entire model at once
100
+ model.to(device)
101
+ quantize(model, weights=weight_quant, activations=activations_quant, exclude=EXCLUDE_PATTERNS)
102
+ freeze(model)
103
+
104
+ # Restore model to original device
105
+ model.to(original_device)
106
+
107
+ return model
108
+
109
+
110
+ def _quantize_blockwise(
111
+ model: torch.nn.Module,
112
+ weight_quant: torch.dtype,
113
+ activations_quant: torch.dtype | None,
114
+ device: torch.device,
115
+ ) -> None:
116
+ """Quantize a model block-by-block using optimum-quanto.
117
+ This approach:
118
+ 1. Moves each transformer block to GPU
119
+ 2. Quantizes on GPU (fast!)
120
+ 3. Freezes the quantized weights
121
+ 4. Moves back to CPU
122
+ This is much faster than quantizing on CPU and uses less peak VRAM
123
+ than loading the entire model to GPU.
124
+ """
125
+ from optimum.quanto import freeze, quantize # noqa: PLC0415
126
+
127
+ original_dtype = next(model.parameters()).dtype
128
+ transformer_blocks = list(model.transformer_blocks)
129
+
130
+ with Progress(
131
+ SpinnerColumn(),
132
+ TextColumn("[progress.description]{task.description}"),
133
+ BarColumn(),
134
+ TaskProgressColumn(),
135
+ transient=True,
136
+ ) as progress:
137
+ task = progress.add_task("Quantizing transformer blocks", total=len(transformer_blocks))
138
+
139
+ for block in transformer_blocks:
140
+ # Move block to GPU
141
+ block.to(device, dtype=original_dtype, non_blocking=True)
142
+
143
+ # Quantize on GPU
144
+ quantize(block, weights=weight_quant, activations=activations_quant, exclude=EXCLUDE_PATTERNS)
145
+ freeze(block)
146
+
147
+ # Move back to CPU to free up VRAM for next block
148
+ block.to("cpu", non_blocking=True)
149
+
150
+ progress.advance(task)
151
+
152
+ # Quantize remaining non-transformer-block modules (e.g., embeddings, timestep projections)
153
+ # Skip modules that should not be quantized (patchify_proj, proj_out, etc.)
154
+ logger.debug("Quantizing remaining model components")
155
+
156
+ for name, module in model.named_children():
157
+ if name == "transformer_blocks":
158
+ continue # Already quantized
159
+
160
+ if name in SKIP_ROOT_MODULES:
161
+ logger.debug(f"Skipping quantization for module: {name}")
162
+ continue # Don't quantize these modules
163
+
164
+ # Move to device, quantize, freeze, move back
165
+ module.to(device, dtype=original_dtype, non_blocking=True)
166
+ quantize(module, weights=weight_quant, activations=activations_quant, exclude=EXCLUDE_PATTERNS)
167
+ freeze(module)
168
+ module.to("cpu", non_blocking=True)
169
+
170
+
171
+ def _get_quanto_dtype(precision: QuantizationOptions) -> torch.dtype:
172
+ """Map precision string to quanto dtype."""
173
+ from optimum.quanto import ( # noqa: PLC0415
174
+ qfloat8,
175
+ qfloat8_e4m3fnuz,
176
+ qint2,
177
+ qint4,
178
+ qint8,
179
+ )
180
+
181
+ if precision == "int2-quanto":
182
+ return qint2
183
+ elif precision == "int4-quanto":
184
+ return qint4
185
+ elif precision == "int8-quanto":
186
+ return qint8
187
+ elif precision in ("fp8-quanto", "fp8uz-quanto"):
188
+ if torch.backends.mps.is_available():
189
+ raise ValueError("FP8 quantization is not supported on MPS devices. Use int2, int4, or int8 instead.")
190
+ if precision == "fp8-quanto":
191
+ return qfloat8
192
+ elif precision == "fp8uz-quanto":
193
+ return qfloat8_e4m3fnuz
194
+
195
+ raise ValueError(f"Invalid quantization precision: {precision}")
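For orientation, a minimal usage sketch of the quantization helper above (illustrative only, not part of the uploaded file; the trainer itself passes its transformer together with the configured `acceleration.quantization` value):

import torch
from ltx_trainer.quantization import quantize_model

# Tiny stand-in module; in the trainer this is the LTX-2 transformer.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8))

# Quantize weights to int8 with optimum-quanto; activations stay unquantized by default.
model = quantize_model(model, precision="int8-quanto")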
packages/ltx-trainer/src/ltx_trainer/trainer.py ADDED
@@ -0,0 +1,1000 @@
1
+ import os
2
+ import time
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Callable
6
+
7
+ import torch
8
+ import wandb
9
+ import yaml
10
+ from accelerate import Accelerator, DistributedType
11
+ from accelerate.utils import set_seed
12
+ from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
13
+ from peft.tuners.tuners_utils import BaseTunerLayer
14
+ from peft.utils import ModulesToSaveWrapper
15
+ from pydantic import BaseModel
16
+ from safetensors.torch import load_file, save_file
17
+ from torch import Tensor
18
+ from torch.optim import AdamW
19
+ from torch.optim.lr_scheduler import (
20
+ CosineAnnealingLR,
21
+ CosineAnnealingWarmRestarts,
22
+ LinearLR,
23
+ LRScheduler,
24
+ PolynomialLR,
25
+ StepLR,
26
+ )
27
+ from torch.utils.data import DataLoader
28
+ from torchvision.transforms import functional as F # noqa: N812
29
+
30
+ from ltx_core.text_encoders.gemma import convert_to_additive_mask
31
+ from ltx_trainer import logger
32
+ from ltx_trainer.config import LtxTrainerConfig
33
+ from ltx_trainer.config_display import print_config
34
+ from ltx_trainer.datasets import PrecomputedDataset
35
+ from ltx_trainer.gpu_utils import free_gpu_memory, free_gpu_memory_context, get_gpu_memory_gb
36
+ from ltx_trainer.hf_hub_utils import push_to_hub
37
+ from ltx_trainer.model_loader import load_embeddings_processor, load_text_encoder
38
+ from ltx_trainer.model_loader import load_model as load_ltx_model
39
+ from ltx_trainer.progress import TrainingProgress
40
+ from ltx_trainer.quantization import quantize_model
41
+ from ltx_trainer.timestep_samplers import SAMPLERS
42
+ from ltx_trainer.training_strategies import get_training_strategy
43
+ from ltx_trainer.utils import open_image_as_srgb, save_image
44
+ from ltx_trainer.validation_sampler import CachedPromptEmbeddings, GenerationConfig, ValidationSampler
45
+ from ltx_trainer.video_utils import read_video, save_video
46
+
47
+ # Silence the tokenizers parallelism fork warning from the transformers/tokenizers stack
48
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
49
+
50
+ # Silence bitsandbytes warnings about casting
51
+ warnings.filterwarnings(
52
+ "ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization"
53
+ )
54
+
55
+ # Disable progress bars if not main process
56
+ IS_MAIN_PROCESS = os.environ.get("LOCAL_RANK", "0") == "0"
57
+ if not IS_MAIN_PROCESS:
58
+ from transformers.utils.logging import disable_progress_bar
59
+
60
+ disable_progress_bar()
61
+
62
+ StepCallback = Callable[[int, int, list[Path] | None], None] # (step, total_steps, sampled_video_paths; None before the first validation run)
63
+
64
+ MEMORY_CHECK_INTERVAL = 200
65
+
66
+
67
+ class TrainingStats(BaseModel):
68
+ """Statistics collected during training"""
69
+
70
+ total_time_seconds: float
71
+ steps_per_second: float
72
+ samples_per_second: float
73
+ peak_gpu_memory_gb: float
74
+ global_batch_size: int
75
+ num_processes: int
76
+
77
+
78
+ class LtxvTrainer:
79
+ def __init__(self, trainer_config: LtxTrainerConfig) -> None:
80
+ self._config = trainer_config
81
+ if IS_MAIN_PROCESS:
82
+ print_config(trainer_config)
83
+ self._training_strategy = get_training_strategy(self._config.training_strategy)
84
+ self._cached_validation_embeddings = self._load_text_encoder_and_cache_embeddings()
85
+ self._load_models()
86
+ self._setup_accelerator()
87
+ self._collect_trainable_params()
88
+ self._load_checkpoint()
89
+ self._prepare_models_for_training()
90
+ self._dataset = None
91
+ self._global_step = -1
92
+ self._checkpoint_paths = []
93
+ self._init_wandb()
94
+
95
+ def train( # noqa: PLR0912, PLR0915
96
+ self,
97
+ disable_progress_bars: bool = False,
98
+ step_callback: StepCallback | None = None,
99
+ ) -> tuple[Path, TrainingStats]:
100
+ """
101
+ Start the training process.
102
+ Returns:
103
+ Tuple of (saved_model_path, training_stats)
104
+ """
105
+ device = self._accelerator.device
106
+ cfg = self._config
107
+ start_mem = get_gpu_memory_gb(device)
108
+
109
+ train_start_time = time.time()
110
+
111
+ # Use the same seed for all processes and ensure deterministic operations
112
+ set_seed(cfg.seed)
113
+ logger.debug(f"Process {self._accelerator.process_index} using seed: {cfg.seed}")
114
+
115
+ self._init_optimizer()
116
+ self._init_dataloader()
117
+ data_iter = iter(self._dataloader)
118
+ self._init_timestep_sampler()
119
+
120
+ # Synchronize all processes after initialization
121
+ self._accelerator.wait_for_everyone()
122
+
123
+ Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
124
+
125
+ # Save the training configuration as YAML
126
+ self._save_config()
127
+
128
+ logger.info("🚀 Starting training...")
129
+
130
+ # Create progress tracking (disabled for non-main processes or when explicitly disabled)
131
+ progress_enabled = IS_MAIN_PROCESS and not disable_progress_bars
132
+ progress = TrainingProgress(
133
+ enabled=progress_enabled,
134
+ total_steps=cfg.optimization.steps,
135
+ )
136
+
137
+ if IS_MAIN_PROCESS and disable_progress_bars:
138
+ logger.warning("Progress bars disabled. Intermediate status messages will be logged instead.")
139
+
140
+ self._transformer.train()
141
+ self._global_step = 0
142
+
143
+ peak_mem_during_training = start_mem
144
+
145
+ sampled_videos_paths = None
146
+
147
+ with progress:
148
+ # Initial validation before training starts
149
+ if cfg.validation.interval and not cfg.validation.skip_initial_validation:
150
+ sampled_videos_paths = self._sample_videos(progress)
151
+ if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
152
+ self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts)
153
+
154
+ self._accelerator.wait_for_everyone()
155
+
156
+ for step in range(cfg.optimization.steps * cfg.optimization.gradient_accumulation_steps):
157
+ # Get next batch, reset the dataloader if needed
158
+ try:
159
+ batch = next(data_iter)
160
+ except StopIteration:
161
+ data_iter = iter(self._dataloader)
162
+ batch = next(data_iter)
163
+
164
+ step_start_time = time.time()
165
+ with self._accelerator.accumulate(self._transformer):
166
+ is_optimization_step = (step + 1) % cfg.optimization.gradient_accumulation_steps == 0
167
+ if is_optimization_step:
168
+ self._global_step += 1
169
+
170
+ loss = self._training_step(batch)
171
+ self._accelerator.backward(loss)
172
+
173
+ if self._accelerator.sync_gradients and cfg.optimization.max_grad_norm > 0:
174
+ self._accelerator.clip_grad_norm_(
175
+ self._trainable_params,
176
+ cfg.optimization.max_grad_norm,
177
+ )
178
+
179
+ self._optimizer.step()
180
+ self._optimizer.zero_grad()
181
+
182
+ if self._lr_scheduler is not None:
183
+ self._lr_scheduler.step()
184
+
185
+ # Run validation if needed
186
+ if (
187
+ cfg.validation.interval
188
+ and self._global_step > 0
189
+ and self._global_step % cfg.validation.interval == 0
190
+ and is_optimization_step
191
+ ):
192
+ if self._accelerator.distributed_type == DistributedType.FSDP:
193
+ # FSDP: All processes must participate in validation
194
+ sampled_videos_paths = self._sample_videos(progress)
195
+ if IS_MAIN_PROCESS and sampled_videos_paths and self._config.wandb.log_validation_videos:
196
+ self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts)
197
+ # DDP: Only main process runs validation
198
+ elif IS_MAIN_PROCESS:
199
+ sampled_videos_paths = self._sample_videos(progress)
200
+ if sampled_videos_paths and self._config.wandb.log_validation_videos:
201
+ self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts)
202
+
203
+ # Save checkpoint if needed
204
+ if (
205
+ cfg.checkpoints.interval
206
+ and self._global_step > 0
207
+ and self._global_step % cfg.checkpoints.interval == 0
208
+ and is_optimization_step
209
+ ):
210
+ self._save_checkpoint()
211
+
212
+ self._accelerator.wait_for_everyone()
213
+
214
+ # Call step callback if provided
215
+ if step_callback and is_optimization_step:
216
+ step_callback(self._global_step, cfg.optimization.steps, sampled_videos_paths)
217
+
218
+ self._accelerator.wait_for_everyone()
219
+
220
+ # Update progress and log metrics
221
+ current_lr = self._optimizer.param_groups[0]["lr"]
222
+ step_time = (time.time() - step_start_time) * cfg.optimization.gradient_accumulation_steps
223
+
224
+ progress.update_training(
225
+ loss=loss.item(),
226
+ lr=current_lr,
227
+ step_time=step_time,
228
+ advance=is_optimization_step,
229
+ )
230
+
231
+ # Log metrics to W&B (only on main process and optimization steps)
232
+ if IS_MAIN_PROCESS and is_optimization_step:
233
+ self._log_metrics(
234
+ {
235
+ "train/loss": loss.item(),
236
+ "train/learning_rate": current_lr,
237
+ "train/step_time": step_time,
238
+ "train/global_step": self._global_step,
239
+ }
240
+ )
241
+
242
+ # Fallback logging when progress bars are disabled
243
+ if disable_progress_bars and IS_MAIN_PROCESS and self._global_step % 20 == 0:
244
+ elapsed = time.time() - train_start_time
245
+ progress_percentage = self._global_step / cfg.optimization.steps
246
+ if progress_percentage > 0:
247
+ total_estimated = elapsed / progress_percentage
248
+ total_time = f"{total_estimated // 3600:.0f}h {(total_estimated % 3600) // 60:.0f}m"
249
+ else:
250
+ total_time = "calculating..."
251
+ logger.info(
252
+ f"Step {self._global_step}/{cfg.optimization.steps} - "
253
+ f"Loss: {loss.item():.4f}, LR: {current_lr:.2e}, "
254
+ f"Time/Step: {step_time:.2f}s, Total Time: {total_time}",
255
+ )
256
+
257
+ # Sample GPU memory periodically
258
+ if step % MEMORY_CHECK_INTERVAL == 0:
259
+ current_mem = get_gpu_memory_gb(device)
260
+ peak_mem_during_training = max(peak_mem_during_training, current_mem)
261
+
262
+ # Collect final stats
263
+ train_end_time = time.time()
264
+ end_mem = get_gpu_memory_gb(device)
265
+ peak_mem = max(start_mem, end_mem, peak_mem_during_training)
266
+
267
+ # Calculate steps/second over entire training
268
+ total_time_seconds = train_end_time - train_start_time
269
+ steps_per_second = cfg.optimization.steps / total_time_seconds
270
+
271
+ samples_per_second = steps_per_second * self._accelerator.num_processes * cfg.optimization.batch_size
272
+
273
+ stats = TrainingStats(
274
+ total_time_seconds=total_time_seconds,
275
+ steps_per_second=steps_per_second,
276
+ samples_per_second=samples_per_second,
277
+ peak_gpu_memory_gb=peak_mem,
278
+ num_processes=self._accelerator.num_processes,
279
+ global_batch_size=cfg.optimization.batch_size * self._accelerator.num_processes,
280
+ )
281
+
282
+ saved_path = self._save_checkpoint()
283
+
284
+ if IS_MAIN_PROCESS:
285
+ # Log the training statistics
286
+ self._log_training_stats(stats)
287
+
288
+ # Upload artifacts to hub if enabled
289
+ if cfg.hub.push_to_hub:
290
+ push_to_hub(saved_path, sampled_videos_paths, self._config)
291
+
292
+ # Log final stats to W&B
293
+ if self._wandb_run is not None:
294
+ self._log_metrics(
295
+ {
296
+ "stats/total_time_minutes": stats.total_time_seconds / 60,
297
+ "stats/steps_per_second": stats.steps_per_second,
298
+ "stats/samples_per_second": stats.samples_per_second,
299
+ "stats/peak_gpu_memory_gb": stats.peak_gpu_memory_gb,
300
+ }
301
+ )
302
+ self._wandb_run.finish()
303
+
304
+ self._accelerator.wait_for_everyone()
305
+ self._accelerator.end_training()
306
+
307
+ return saved_path, stats
308
+
309
+ def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor:
310
+ """Perform a single training step using the configured strategy."""
311
+ # Apply embedding connectors to transform pre-computed text embeddings
312
+ conditions = batch["conditions"]
313
+
314
+ if "video_prompt_embeds" in conditions:
315
+ # New format: separate video/audio features from precompute()
316
+ video_features = conditions["video_prompt_embeds"]
317
+ audio_features = conditions.get("audio_prompt_embeds")
318
+ else:
319
+ # Legacy format: single prompt_embeds tensor — duplicate for both modalities
320
+ video_features = conditions["prompt_embeds"]
321
+ audio_features = conditions["prompt_embeds"]
322
+
323
+ mask = conditions["prompt_attention_mask"]
324
+ additive_mask = convert_to_additive_mask(mask, video_features.dtype)
325
+ video_embeds, audio_embeds, attention_mask = self._embeddings_processor.create_embeddings(
326
+ video_features, audio_features, additive_mask
327
+ )
328
+
329
+ conditions["video_prompt_embeds"] = video_embeds
330
+ conditions["audio_prompt_embeds"] = audio_embeds
331
+ conditions["prompt_attention_mask"] = attention_mask
332
+
333
+ # Use strategy to prepare training inputs (returns ModelInputs with Modality objects)
334
+ model_inputs = self._training_strategy.prepare_training_inputs(batch, self._timestep_sampler)
335
+
336
+ # Run transformer forward pass with Modality-based interface
337
+ video_pred, audio_pred = self._transformer(
338
+ video=model_inputs.video,
339
+ audio=model_inputs.audio,
340
+ perturbations=None,
341
+ )
342
+
343
+ # Use strategy to compute loss
344
+ loss = self._training_strategy.compute_loss(video_pred, audio_pred, model_inputs)
345
+
346
+ return loss
347
+
348
+ @free_gpu_memory_context(after=True)
349
+ def _load_text_encoder_and_cache_embeddings(self) -> list[CachedPromptEmbeddings] | None:
350
+ """Load text encoder + embeddings processor, compute and cache validation embeddings."""
351
+
352
+ # This method:
353
+ # 1. Loads the pure Gemma text encoder on GPU
354
+ # 2. Loads the embeddings processor (feature extractor + connectors)
355
+ # 3. If validation prompts are configured, computes and caches their embeddings
356
+ # 4. Unloads the Gemma model entirely, keeps the embeddings processor for training
357
+
358
+ # Load text encoder (pure Gemma LLM) on GPU
359
+ logger.debug("Loading text encoder...")
360
+ text_encoder = load_text_encoder(
361
+ gemma_model_path=self._config.model.text_encoder_path,
362
+ device="cuda",
363
+ dtype=torch.bfloat16,
364
+ load_in_8bit=self._config.acceleration.load_text_encoder_in_8bit,
365
+ )
366
+
367
+ # Load embeddings processor (feature extractor + connectors)
368
+ logger.debug("Loading embeddings processor...")
369
+ self._embeddings_processor = load_embeddings_processor(
370
+ checkpoint_path=self._config.model.model_path,
371
+ device="cuda",
372
+ dtype=torch.bfloat16,
373
+ )
374
+
375
+ # Cache validation embeddings if prompts are configured
376
+ cached_embeddings = None
377
+ if self._config.validation.prompts:
378
+ logger.info(f"Pre-computing embeddings for {len(self._config.validation.prompts)} validation prompts...")
379
+ cached_embeddings = []
380
+ with torch.inference_mode():
381
+ for prompt in self._config.validation.prompts:
382
+ pos_hs, pos_mask = text_encoder.encode(prompt)
383
+ pos_out = self._embeddings_processor.process_hidden_states(pos_hs, pos_mask)
384
+
385
+ neg_hs, neg_mask = text_encoder.encode(self._config.validation.negative_prompt)
386
+ neg_out = self._embeddings_processor.process_hidden_states(neg_hs, neg_mask)
387
+
388
+ cached_embeddings.append(
389
+ CachedPromptEmbeddings(
390
+ video_context_positive=pos_out.video_encoding.cpu(),
391
+ audio_context_positive=pos_out.audio_encoding.cpu(),
392
+ video_context_negative=neg_out.video_encoding.cpu(),
393
+ audio_context_negative=(
394
+ neg_out.audio_encoding.cpu() if neg_out.audio_encoding is not None else None
395
+ ),
396
+ )
397
+ )
398
+
399
+ # Unload Gemma model and feature extractor, keep only connectors for training
400
+ del text_encoder
401
+ self._embeddings_processor.feature_extractor = None
402
+
403
+ logger.debug("Validation prompt embeddings cached. Gemma model unloaded")
404
+ return cached_embeddings
405
+
406
+ def _load_models(self) -> None:
407
+ """Load the LTX-2 model components."""
408
+ # Load audio components if:
409
+ # 1. Training strategy requires audio (training the audio branch), OR
410
+ # 2. Validation is configured to generate audio (even if not training audio)
411
+ load_audio = self._training_strategy.requires_audio or self._config.validation.generate_audio
412
+
413
+ # Check if we need VAE encoder (for image or reference video conditioning)
414
+ need_vae_encoder = (
415
+ self._config.validation.images is not None or self._config.validation.reference_videos is not None
416
+ )
417
+
418
+ # Load all model components (except text encoder - already handled)
419
+ components = load_ltx_model(
420
+ checkpoint_path=self._config.model.model_path,
421
+ device="cpu",
422
+ dtype=torch.bfloat16,
423
+ with_video_vae_encoder=need_vae_encoder, # Needed for image conditioning
424
+ with_video_vae_decoder=True, # Needed for validation sampling
425
+ with_audio_vae_decoder=load_audio,
426
+ with_vocoder=load_audio,
427
+ with_text_encoder=False, # Text encoder handled separately
428
+ )
429
+
430
+ # Extract components
431
+ self._transformer = components.transformer
432
+ self._vae_decoder = components.video_vae_decoder.to(dtype=torch.bfloat16)
433
+ self._vae_encoder = components.video_vae_encoder
434
+ if self._vae_encoder is not None:
435
+ self._vae_encoder = self._vae_encoder.to(dtype=torch.bfloat16)
436
+ self._scheduler = components.scheduler
437
+ self._audio_vae = components.audio_vae_decoder
438
+ self._vocoder = components.vocoder
439
+ # Note: self._embeddings_processor was set in _load_text_encoder_and_cache_embeddings
440
+
441
+ # Determine initial dtype based on training mode.
442
+ # Note: For FSDP + LoRA, we'll cast to FP32 later in _prepare_models_for_training()
443
+ # after the accelerator is set up, and we can detect FSDP.
444
+ transformer_dtype = torch.bfloat16 if self._config.model.training_mode == "lora" else torch.float32
445
+ self._transformer = self._transformer.to(dtype=transformer_dtype)
446
+
447
+ if self._config.acceleration.quantization is not None:
448
+ if self._config.model.training_mode == "full":
449
+ raise ValueError("Quantization is not supported in full training mode.")
450
+
451
+ logger.info(f'Quantizing model with "{self._config.acceleration.quantization}". This may take a while...')
452
+ self._transformer = quantize_model(
453
+ self._transformer,
454
+ precision=self._config.acceleration.quantization,
455
+ )
456
+
457
+ # Freeze all models. We later unfreeze the transformer based on training mode.
458
+ # Note: embedding_connectors are already frozen (they come from the frozen text encoder)
459
+ self._vae_decoder.requires_grad_(False)
460
+ if self._vae_encoder is not None:
461
+ self._vae_encoder.requires_grad_(False)
462
+ self._transformer.requires_grad_(False)
463
+ if self._audio_vae is not None:
464
+ self._audio_vae.requires_grad_(False)
465
+ if self._vocoder is not None:
466
+ self._vocoder.requires_grad_(False)
467
+
468
+ def _collect_trainable_params(self) -> None:
469
+ """Collect trainable parameters based on training mode."""
470
+ if self._config.model.training_mode == "lora":
471
+ # For LoRA training, first set up LoRA layers
472
+ self._setup_lora()
473
+ elif self._config.model.training_mode == "full":
474
+ # For full training, unfreeze all transformer parameters
475
+ self._transformer.requires_grad_(True)
476
+ else:
477
+ raise ValueError(f"Unknown training mode: {self._config.model.training_mode}")
478
+
479
+ self._trainable_params = [p for p in self._transformer.parameters() if p.requires_grad]
480
+ logger.debug(f"Trainable params count: {sum(p.numel() for p in self._trainable_params):,}")
481
+
482
+ def _init_timestep_sampler(self) -> None:
483
+ """Initialize the timestep sampler based on the config."""
484
+ sampler_cls = SAMPLERS[self._config.flow_matching.timestep_sampling_mode]
485
+ self._timestep_sampler = sampler_cls(**self._config.flow_matching.timestep_sampling_params)
486
+
487
+ def _setup_lora(self) -> None:
488
+ """Configure LoRA adapters for the transformer. Only called in LoRA training mode."""
489
+ logger.debug(f"Adding LoRA adapter with rank {self._config.lora.rank}")
490
+ lora_config = LoraConfig(
491
+ r=self._config.lora.rank,
492
+ lora_alpha=self._config.lora.alpha,
493
+ target_modules=self._config.lora.target_modules,
494
+ lora_dropout=self._config.lora.dropout,
495
+ init_lora_weights=True,
496
+ )
497
+ # Wrap the transformer with PEFT to add LoRA layers
498
+ # noinspection PyTypeChecker
499
+ self._transformer = get_peft_model(self._transformer, lora_config)
500
+
501
+ def _load_checkpoint(self) -> None:
502
+ """Load checkpoint if specified in config."""
503
+ if not self._config.model.load_checkpoint:
504
+ return
505
+
506
+ checkpoint_path = self._find_checkpoint(self._config.model.load_checkpoint)
507
+ if not checkpoint_path:
508
+ logger.warning(f"⚠️ Could not find checkpoint at {self._config.model.load_checkpoint}")
509
+ return
510
+
511
+ logger.info(f"📥 Loading checkpoint from {checkpoint_path}")
512
+
513
+ if self._config.model.training_mode == "full":
514
+ self._load_full_checkpoint(checkpoint_path)
515
+ else: # LoRA mode
516
+ self._load_lora_checkpoint(checkpoint_path)
517
+
518
+ def _load_full_checkpoint(self, checkpoint_path: Path) -> None:
519
+ """Load full model checkpoint."""
520
+ state_dict = load_file(checkpoint_path)
521
+ self._transformer.load_state_dict(state_dict, strict=True)
522
+
523
+ logger.info("✅ Full model checkpoint loaded successfully")
524
+
525
+ def _load_lora_checkpoint(self, checkpoint_path: Path) -> None:
526
+ """Load LoRA checkpoint with DDP/FSDP compatibility."""
527
+ state_dict = load_file(checkpoint_path)
528
+
529
+ # Adjust layer names to match internal format.
530
+ # (Weights are saved in ComfyUI-compatible format, with "diffusion_model." prefix)
531
+ state_dict = {k.replace("diffusion_model.", "", 1): v for k, v in state_dict.items()}
532
+
533
+ # Load the LoRA adapter weights into the base model
534
+ base_model = self._transformer.get_base_model()
535
+ set_peft_model_state_dict(base_model, state_dict)
536
+
537
+ logger.info("✅ LoRA checkpoint loaded successfully")
538
+
539
+ def _prepare_models_for_training(self) -> None:
540
+ """Prepare models for training with Accelerate."""
541
+
542
+ # For FSDP + LoRA: Cast entire model to FP32.
543
+ # FSDP requires uniform dtype across all parameters in wrapped modules.
544
+ # In LoRA mode, PEFT creates LoRA params in FP32 while base model is BF16.
545
+ # We cast the base model to FP32 to match the LoRA params.
546
+ if self._accelerator.distributed_type == DistributedType.FSDP and self._config.model.training_mode == "lora":
547
+ logger.debug("FSDP: casting transformer to FP32 for uniform dtype")
548
+ self._transformer = self._transformer.to(dtype=torch.float32)
549
+
550
+ # Enable gradient checkpointing if requested
551
+ # For PeftModel, we need to access the underlying base model
552
+ transformer = (
553
+ self._transformer.get_base_model() if hasattr(self._transformer, "get_base_model") else self._transformer
554
+ )
555
+
556
+ transformer.set_gradient_checkpointing(self._config.optimization.enable_gradient_checkpointing)
557
+
558
+ # Keep frozen models on CPU for memory efficiency
559
+ self._vae_decoder = self._vae_decoder.to("cpu")
560
+ if self._vae_encoder is not None:
561
+ self._vae_encoder = self._vae_encoder.to("cpu")
562
+
563
+ # Embedding connectors are already on GPU from _load_text_encoder_and_cache_embeddings
564
+
565
+ # noinspection PyTypeChecker
566
+ self._transformer = self._accelerator.prepare(self._transformer)
567
+
568
+ # Log GPU memory usage after model preparation
569
+ vram_usage_gb = torch.cuda.memory_allocated() / 1024**3
570
+ logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB")
571
+
572
+ @staticmethod
573
+ def _find_checkpoint(checkpoint_path: str | Path) -> Path | None:
574
+ """Find the checkpoint file to load, handling both file and directory paths."""
575
+ checkpoint_path = Path(checkpoint_path)
576
+
577
+ if checkpoint_path.is_file():
578
+ if checkpoint_path.suffix != ".safetensors":
579
+ raise ValueError(f"Checkpoint file must have a .safetensors extension: {checkpoint_path}")
580
+ return checkpoint_path
581
+
582
+ if checkpoint_path.is_dir():
583
+ # Look for checkpoint files in the directory
584
+ checkpoints = list(checkpoint_path.rglob("*step_*.safetensors"))
585
+
586
+ if not checkpoints:
587
+ return None
588
+
589
+ # Sort by step number and return the latest
590
+ def _get_step_num(p: Path) -> int:
591
+ try:
592
+ return int(p.stem.split("step_")[1])
593
+ except (IndexError, ValueError):
594
+ return -1
595
+
596
+ latest = max(checkpoints, key=_get_step_num)
597
+ return latest
598
+
599
+ else:
600
+ raise ValueError(f"Invalid checkpoint path: {checkpoint_path}. Must be a file or directory.")
601
+
602
+ def _init_dataloader(self) -> None:
603
+ """Initialize the training data loader using the strategy's data sources."""
604
+ if self._dataset is None:
605
+ # Get data sources from the training strategy
606
+ data_sources = self._training_strategy.get_data_sources()
607
+
608
+ self._dataset = PrecomputedDataset(self._config.data.preprocessed_data_root, data_sources=data_sources)
609
+ logger.debug(f"Loaded dataset with {len(self._dataset):,} samples from sources: {list(data_sources)}")
610
+
611
+ num_workers = self._config.data.num_dataloader_workers
612
+ dataloader = DataLoader(
613
+ self._dataset,
614
+ batch_size=self._config.optimization.batch_size,
615
+ shuffle=True,
616
+ drop_last=True,
617
+ num_workers=num_workers,
618
+ pin_memory=num_workers > 0,
619
+ persistent_workers=num_workers > 0,
620
+ )
621
+
622
+ self._dataloader = self._accelerator.prepare(dataloader)
623
+
624
+ def _init_lora_weights(self) -> None:
625
+ """Initialize LoRA weights for the transformer."""
626
+ logger.debug("Initializing LoRA weights...")
627
+ for _, module in self._transformer.named_modules():
628
+ if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
629
+ module.reset_lora_parameters(adapter_name="default", init_lora_weights=True)
630
+
631
+ def _init_optimizer(self) -> None:
632
+ """Initialize the optimizer and learning rate scheduler."""
633
+ opt_cfg = self._config.optimization
634
+
635
+ lr = opt_cfg.learning_rate
636
+ if opt_cfg.optimizer_type == "adamw":
637
+ optimizer = AdamW(self._trainable_params, lr=lr)
638
+ elif opt_cfg.optimizer_type == "adamw8bit":
639
+ # noinspection PyUnresolvedReferences
640
+ from bitsandbytes.optim import AdamW8bit # noqa: PLC0415
641
+
642
+ optimizer = AdamW8bit(self._trainable_params, lr=lr)
643
+ else:
644
+ raise ValueError(f"Unknown optimizer type: {opt_cfg.optimizer_type}")
645
+
646
+ # Add scheduler initialization
647
+ lr_scheduler = self._create_scheduler(optimizer)
648
+
649
+ # noinspection PyTypeChecker
650
+ self._optimizer, self._lr_scheduler = self._accelerator.prepare(optimizer, lr_scheduler)
651
+
652
+ def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> LRScheduler | None:
653
+ """Create learning rate scheduler based on config."""
654
+ scheduler_type = self._config.optimization.scheduler_type
655
+ steps = self._config.optimization.steps
656
+ params = self._config.optimization.scheduler_params or {}
657
+
658
+ if scheduler_type is None:
659
+ return None
660
+
661
+ if scheduler_type == "linear":
662
+ scheduler = LinearLR(
663
+ optimizer,
664
+ start_factor=params.pop("start_factor", 1.0),
665
+ end_factor=params.pop("end_factor", 0.1),
666
+ total_iters=steps,
667
+ **params,
668
+ )
669
+ elif scheduler_type == "cosine":
670
+ scheduler = CosineAnnealingLR(
671
+ optimizer,
672
+ T_max=steps,
673
+ eta_min=params.pop("eta_min", 0),
674
+ **params,
675
+ )
676
+ elif scheduler_type == "cosine_with_restarts":
677
+ scheduler = CosineAnnealingWarmRestarts(
678
+ optimizer,
679
+ T_0=params.pop("T_0", steps // 4), # First restart cycle length
680
+ T_mult=params.pop("T_mult", 1), # Multiplicative factor for cycle lengths
681
+ eta_min=params.pop("eta_min", 5e-5),
682
+ **params,
683
+ )
684
+ elif scheduler_type == "polynomial":
685
+ scheduler = PolynomialLR(
686
+ optimizer,
687
+ total_iters=steps,
688
+ power=params.pop("power", 1.0),
689
+ **params,
690
+ )
691
+ elif scheduler_type == "step":
692
+ scheduler = StepLR(
693
+ optimizer,
694
+ step_size=params.pop("step_size", steps // 2),
695
+ gamma=params.pop("gamma", 0.1),
696
+ **params,
697
+ )
698
+ elif scheduler_type == "constant":
699
+ scheduler = None
700
+ else:
701
+ raise ValueError(f"Unknown scheduler type: {scheduler_type}")
702
+
703
+ return scheduler
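For reference, the scheduler branches above are driven by two config fields; a hypothetical optimization section (shown as a Python dict, the exact schema lives in LtxTrainerConfig) might look like:

optimization = {
    "steps": 2000,
    "scheduler_type": "cosine_with_restarts",
    # Remaining keys are popped as defaults, the rest is forwarded as **params.
    "scheduler_params": {"T_0": 500, "T_mult": 1, "eta_min": 5e-5},
}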
704
+
705
+ def _setup_accelerator(self) -> None:
706
+ """Initialize the Accelerator with the appropriate settings."""
707
+
708
+ # All distributed setup (DDP/FSDP, number of processes, etc.) is controlled by
709
+ # the user's Accelerate configuration (accelerate config / accelerate launch).
710
+ self._accelerator = Accelerator(
711
+ mixed_precision=self._config.acceleration.mixed_precision_mode,
712
+ gradient_accumulation_steps=self._config.optimization.gradient_accumulation_steps,
713
+ )
714
+
715
+ if self._accelerator.num_processes > 1:
716
+ logger.info(
717
+ f"{self._accelerator.distributed_type.value} distributed training enabled "
718
+ f"with {self._accelerator.num_processes} processes"
719
+ )
720
+
721
+ local_batch = self._config.optimization.batch_size
722
+ global_batch = self._config.optimization.batch_size * self._accelerator.num_processes
723
+ logger.info(f"Local batch size: {local_batch}, global batch size: {global_batch}")
724
+
725
+ # Log torch.compile status from Accelerate's dynamo plugin
726
+ is_compile_enabled = (
727
+ hasattr(self._accelerator.state, "dynamo_plugin") and self._accelerator.state.dynamo_plugin.backend != "NO"
728
+ )
729
+ if is_compile_enabled:
730
+ plugin = self._accelerator.state.dynamo_plugin
731
+ logger.info(f"🔥 torch.compile enabled via Accelerate: backend={plugin.backend}, mode={plugin.mode}")
732
+
733
+ if self._accelerator.distributed_type == DistributedType.FSDP:
734
+ logger.warning(
735
+ "⚠️ FSDP + torch.compile is experimental and may hang on the first training iteration. "
736
+ "If this occurs, disable torch.compile by removing dynamo_config from your Accelerate config."
737
+ )
738
+
739
+ if self._accelerator.distributed_type == DistributedType.FSDP and self._config.acceleration.quantization:
740
+ logger.warning(
741
+ f"FSDP with quantization ({self._config.acceleration.quantization}) may have compatibility issues."
742
+ "Monitor training stability and consider disabling quantization if issues arise."
743
+ )
744
+
745
+ # Note: Use @torch.no_grad() instead of @torch.inference_mode() to avoid FSDP inplace update errors after validation
746
+ @torch.no_grad()
747
+ @free_gpu_memory_context(after=True)
748
+ def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None:
749
+ """Run validation by generating videos from validation prompts."""
750
+ use_images = self._config.validation.images is not None
751
+ use_reference_videos = self._config.validation.reference_videos is not None
752
+ generate_audio = self._config.validation.generate_audio
753
+ inference_steps = self._config.validation.inference_steps
754
+
755
+ # Zero gradients and free GPU memory to reclaim memory before validation sampling
756
+ self._optimizer.zero_grad(set_to_none=True)
757
+ free_gpu_memory()
758
+
759
+ # Start sampling progress tracking
760
+ sampling_ctx = progress.start_sampling(
761
+ num_prompts=len(self._config.validation.prompts),
762
+ num_steps=inference_steps,
763
+ )
764
+
765
+ # Create validation sampler with loaded models and progress tracking
766
+ sampler = ValidationSampler(
767
+ transformer=self._transformer,
768
+ vae_decoder=self._vae_decoder,
769
+ vae_encoder=self._vae_encoder,
770
+ text_encoder=None,
771
+ audio_decoder=self._audio_vae if generate_audio else None,
772
+ vocoder=self._vocoder if generate_audio else None,
773
+ sampling_context=sampling_ctx,
774
+ )
775
+
776
+ output_dir = Path(self._config.output_dir) / "samples"
777
+ output_dir.mkdir(exist_ok=True, parents=True)
778
+
779
+ video_paths = []
780
+ width, height, num_frames = self._config.validation.video_dims
781
+
782
+ for prompt_idx, prompt in enumerate(self._config.validation.prompts):
783
+ # Update progress to show current video
784
+ sampling_ctx.start_video(prompt_idx)
785
+
786
+ # Load conditioning image if provided
787
+ condition_image = None
788
+ if use_images:
789
+ image_path = self._config.validation.images[prompt_idx]
790
+ image = open_image_as_srgb(image_path)
791
+ # Convert PIL image to tensor [C, H, W] in [0, 1]
792
+ condition_image = F.to_tensor(image)
793
+
794
+ # Load reference video if provided (for IC-LoRA)
795
+ reference_video = None
796
+ if use_reference_videos:
797
+ ref_video_path = self._config.validation.reference_videos[prompt_idx]
798
+ # read_video returns [F, C, H, W] in [0, 1]
799
+ reference_video, _ = read_video(ref_video_path, max_frames=num_frames)
800
+
801
+ # Get cached embeddings for this prompt if available
802
+ cached_embeddings = (
803
+ self._cached_validation_embeddings[prompt_idx]
804
+ if self._cached_validation_embeddings is not None
805
+ else None
806
+ )
807
+
808
+ # Create generation config
809
+ gen_config = GenerationConfig(
810
+ prompt=prompt,
811
+ negative_prompt=self._config.validation.negative_prompt,
812
+ height=height,
813
+ width=width,
814
+ num_frames=num_frames,
815
+ frame_rate=self._config.validation.frame_rate,
816
+ num_inference_steps=inference_steps,
817
+ guidance_scale=self._config.validation.guidance_scale,
818
+ seed=self._config.validation.seed,
819
+ condition_image=condition_image,
820
+ reference_video=reference_video,
821
+ reference_downscale_factor=self._config.validation.reference_downscale_factor,
822
+ generate_audio=generate_audio,
823
+ include_reference_in_output=self._config.validation.include_reference_in_output,
824
+ cached_embeddings=cached_embeddings,
825
+ stg_scale=self._config.validation.stg_scale,
826
+ stg_blocks=self._config.validation.stg_blocks,
827
+ stg_mode=self._config.validation.stg_mode,
828
+ )
829
+
830
+ # Generate sample
831
+ video, audio = sampler.generate(
832
+ config=gen_config,
833
+ device=self._accelerator.device,
834
+ )
835
+
836
+ # Save output (image for single frame, video otherwise)
837
+ if IS_MAIN_PROCESS:
838
+ ext = "png" if num_frames == 1 else "mp4"
839
+ output_path = output_dir / f"step_{self._global_step:06d}_{prompt_idx + 1}.{ext}"
840
+ if num_frames == 1:
841
+ save_image(video, output_path)
842
+ else:
843
+ save_video(
844
+ video_tensor=video,
845
+ output_path=output_path,
846
+ fps=self._config.validation.frame_rate,
847
+ audio=audio,
848
+ audio_sample_rate=self._vocoder.output_sampling_rate if audio is not None else None,
849
+ )
850
+ video_paths.append(output_path)
851
+
852
+ # Clean up progress tasks
853
+ sampling_ctx.cleanup()
854
+
855
+ rel_outputs_path = output_dir.relative_to(self._config.output_dir)
856
+ logger.info(f"🎥 Validation samples for step {self._global_step} saved in {rel_outputs_path}")
857
+ return video_paths
858
+
859
+ @staticmethod
860
+ def _log_training_stats(stats: TrainingStats) -> None:
861
+ """Log training statistics."""
862
+ stats_str = (
863
+ "📊 Training Statistics:\n"
864
+ f" - Total time: {stats.total_time_seconds / 60:.1f} minutes\n"
865
+ f" - Training speed: {stats.steps_per_second:.2f} steps/second\n"
866
+ f" - Samples/second: {stats.samples_per_second:.2f}\n"
867
+ f" - Peak GPU memory: {stats.peak_gpu_memory_gb:.2f} GB"
868
+ )
869
+ if stats.num_processes > 1:
870
+ stats_str += f"\n - Number of processes: {stats.num_processes}\n"
871
+ stats_str += f" - Global batch size: {stats.global_batch_size}"
872
+ logger.info(stats_str)
873
+
874
+ def _save_checkpoint(self) -> Path | None:
875
+ """Save the model weights."""
876
+ is_lora = self._config.model.training_mode == "lora"
877
+ is_fsdp = self._accelerator.distributed_type == DistributedType.FSDP
878
+
879
+ # Prepare paths
880
+ save_dir = Path(self._config.output_dir) / "checkpoints"
881
+ prefix = "lora" if is_lora else "model"
882
+ filename = f"{prefix}_weights_step_{self._global_step:05d}.safetensors"
883
+ saved_weights_path = save_dir / filename
884
+
885
+ # Get state dict (collective operation - all processes must participate)
886
+ self._accelerator.wait_for_everyone()
887
+ full_state_dict = self._accelerator.get_state_dict(self._transformer)
888
+
889
+ if not IS_MAIN_PROCESS:
890
+ return None
891
+
892
+ save_dir.mkdir(exist_ok=True, parents=True)
893
+
894
+ # Determine save precision
895
+ save_dtype = torch.bfloat16 if self._config.checkpoints.precision == "bfloat16" else torch.float32
896
+
897
+ # For LoRA: extract only adapter weights; for full: use as-is
898
+ if is_lora:
899
+ unwrapped = self._accelerator.unwrap_model(self._transformer, keep_torch_compile=False)
900
+ # For FSDP, pass full_state_dict since model params aren't directly accessible
901
+ state_dict = get_peft_model_state_dict(unwrapped, state_dict=full_state_dict if is_fsdp else None)
902
+
903
+ # Remove "base_model.model." prefix added by PEFT
904
+ state_dict = {k.replace("base_model.model.", "", 1): v for k, v in state_dict.items()}
905
+
906
+ # Convert to ComfyUI-compatible format (add "diffusion_model." prefix)
907
+ state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()}
908
+
909
+ # Cast to configured precision
910
+ state_dict = {k: v.to(save_dtype) if isinstance(v, Tensor) else v for k, v in state_dict.items()}
911
+
912
+ # Build metadata for safetensors file
913
+ metadata = self._build_checkpoint_metadata()
914
+
915
+ # Save to disk with metadata
916
+ save_file(state_dict, saved_weights_path, metadata=metadata)
917
+ else:
918
+ # Cast to configured precision
919
+ full_state_dict = {k: v.to(save_dtype) if isinstance(v, Tensor) else v for k, v in full_state_dict.items()}
920
+
921
+ # Save to disk
922
+ self._accelerator.save(full_state_dict, saved_weights_path)
923
+
924
+ rel_path = saved_weights_path.relative_to(self._config.output_dir)
925
+ logger.info(f"💾 {prefix.capitalize()} weights for step {self._global_step} saved in {rel_path}")
926
+
927
+ # Keep track of checkpoint paths, and cleanup old checkpoints if needed
928
+ self._checkpoint_paths.append(saved_weights_path)
929
+ self._cleanup_checkpoints()
930
+ return saved_weights_path
931
+
932
+ def _cleanup_checkpoints(self) -> None:
933
+ """Clean up old checkpoints."""
934
+ if 0 < self._config.checkpoints.keep_last_n < len(self._checkpoint_paths):
935
+ checkpoints_to_remove = self._checkpoint_paths[: -self._config.checkpoints.keep_last_n]
936
+ for old_checkpoint in checkpoints_to_remove:
937
+ if old_checkpoint.exists():
938
+ old_checkpoint.unlink()
939
+ logger.info(f"Removed old checkpoint: {old_checkpoint}")
940
+ # Update the list to only contain kept checkpoints
941
+ self._checkpoint_paths = self._checkpoint_paths[-self._config.checkpoints.keep_last_n :]
942
+
943
+ def _build_checkpoint_metadata(self) -> dict[str, str]:
944
+ """Build metadata dictionary for safetensors checkpoint.
945
+ Delegates to the training strategy to get strategy-specific metadata
946
+ that downstream inference pipelines may need.
947
+ Returns:
948
+ Dictionary of string key-value pairs for safetensors metadata.
949
+ Values are converted to strings for safetensors compatibility.
950
+ """
951
+ raw_metadata = self._training_strategy.get_checkpoint_metadata()
952
+ # Convert all values to strings for safetensors compatibility
953
+ metadata = {k: str(v) for k, v in raw_metadata.items()}
954
+ if metadata:
955
+ logger.info(f"Saving checkpoint metadata: {metadata}")
956
+ return metadata
957
+
958
+ def _save_config(self) -> None:
959
+ """Save the training configuration as a YAML file in the output directory."""
960
+ if not IS_MAIN_PROCESS:
961
+ return
962
+
963
+ config_path = Path(self._config.output_dir) / "training_config.yaml"
964
+ with open(config_path, "w") as f:
965
+ yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2)
966
+
967
+ logger.info(f"💾 Training configuration saved to: {config_path.relative_to(self._config.output_dir)}")
968
+
969
+ def _init_wandb(self) -> None:
970
+ """Initialize Weights & Biases run."""
971
+ if not self._config.wandb.enabled or not IS_MAIN_PROCESS:
972
+ self._wandb_run = None
973
+ return
974
+
975
+ wandb_config = self._config.wandb
976
+ run = wandb.init(
977
+ project=wandb_config.project,
978
+ entity=wandb_config.entity,
979
+ name=Path(self._config.output_dir).name,
980
+ tags=wandb_config.tags,
981
+ config=self._config.model_dump(),
982
+ )
983
+ self._wandb_run = run
984
+
985
+ def _log_metrics(self, metrics: dict[str, float]) -> None:
986
+ """Log metrics to Weights & Biases."""
987
+ if self._wandb_run is not None:
988
+ self._wandb_run.log(metrics)
989
+
990
+ def _log_validation_samples(self, sample_paths: list[Path], prompts: list[str]) -> None:
991
+ """Log validation samples (videos or images) to Weights & Biases."""
992
+ if not self._config.wandb.log_validation_videos or self._wandb_run is None:
993
+ return
994
+
995
+ # Determine if outputs are images or videos based on file extension
996
+ is_image = sample_paths and sample_paths[0].suffix.lower() in (".png", ".jpg", ".jpeg", ".heic", ".webp")
997
+ media_cls = wandb.Image if is_image else wandb.Video
998
+
999
+ samples = [media_cls(str(path), caption=prompt) for path, prompt in zip(sample_paths, prompts, strict=True)]
1000
+ self._wandb_run.log({"validation_samples": samples}, step=self._global_step)
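A minimal usage sketch of the trainer class, assuming the training config is loaded from a YAML file matching the LtxTrainerConfig schema (the path below is a placeholder; model_validate on the parsed dict is one plausible way to construct the Pydantic config):

import yaml

from ltx_trainer.config import LtxTrainerConfig
from ltx_trainer.trainer import LtxvTrainer

with open("path/to/training_config.yaml") as f:  # placeholder path
    config = LtxTrainerConfig.model_validate(yaml.safe_load(f))

trainer = LtxvTrainer(config)
checkpoint_path, stats = trainer.train()
print(f"Saved weights to {checkpoint_path} ({stats.steps_per_second:.2f} steps/s)")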
packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py ADDED
@@ -0,0 +1,58 @@
1
+ """Training strategies for different conditioning modes.
2
+ This package implements the Strategy Pattern to handle different training modes:
3
+ - Text-to-video training (standard generation, optionally with audio)
4
+ - Video-to-video training (IC-LoRA mode with reference videos)
5
+ Each strategy encapsulates the specific logic for preparing model inputs and computing loss.
6
+ """
7
+
8
+ from ltx_trainer import logger
9
+ from ltx_trainer.training_strategies.base_strategy import (
10
+ DEFAULT_FPS,
11
+ VIDEO_SCALE_FACTORS,
12
+ ModelInputs,
13
+ TrainingStrategy,
14
+ TrainingStrategyConfigBase,
15
+ )
16
+ from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig, TextToVideoStrategy
17
+ from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig, VideoToVideoStrategy
18
+
19
+ # Type alias for all strategy config types
20
+ TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig
21
+
22
+ __all__ = [
23
+ "DEFAULT_FPS",
24
+ "VIDEO_SCALE_FACTORS",
25
+ "ModelInputs",
26
+ "TextToVideoConfig",
27
+ "TextToVideoStrategy",
28
+ "TrainingStrategy",
29
+ "TrainingStrategyConfig",
30
+ "TrainingStrategyConfigBase",
31
+ "VideoToVideoConfig",
32
+ "VideoToVideoStrategy",
33
+ "get_training_strategy",
34
+ ]
35
+
36
+
37
+ def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy:
38
+ """Factory function to create the appropriate training strategy.
39
+ The strategy is determined by the `name` field in the configuration.
40
+ Args:
41
+ config: Strategy-specific configuration with a `name` field
42
+ Returns:
43
+ The appropriate training strategy instance
44
+ Raises:
45
+ ValueError: If strategy name is not supported
46
+ """
47
+
48
+ match config:
49
+ case TextToVideoConfig():
50
+ strategy = TextToVideoStrategy(config)
51
+ case VideoToVideoConfig():
52
+ strategy = VideoToVideoStrategy(config)
53
+ case _:
54
+ raise ValueError(f"Unknown training strategy config type: {type(config).__name__}")
55
+
56
+ audio_mode = "(audio enabled 🔈)" if getattr(config, "with_audio", False) else "(audio disabled 🔇)"
57
+ logger.debug(f"🎯 Using {strategy.__class__.__name__} training strategy {audio_mode}")
58
+ return strategy
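A small usage sketch of the factory (the concrete config values are illustrative):

from ltx_trainer.training_strategies import TextToVideoConfig, get_training_strategy

config = TextToVideoConfig(first_frame_conditioning_p=0.1, with_audio=True)
strategy = get_training_strategy(config)

strategy.requires_audio      # True -> the trainer will load the audio VAE and vocoder
strategy.get_data_sources()  # {"latents": "latents", "conditions": "conditions", "audio_latents": "audio_latents"}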
packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py ADDED
@@ -0,0 +1,262 @@
1
+ """Base class for training strategies.
2
+ This module defines the abstract base class that all training strategies must implement,
3
+ along with the base configuration class.
4
+ """
5
+
6
+ import random
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from typing import Any, Literal
10
+
11
+ import torch
12
+ from pydantic import BaseModel, ConfigDict, Field
13
+ from torch import Tensor
14
+
15
+ from ltx_core.components.patchifiers import (
16
+ AudioPatchifier,
17
+ VideoLatentPatchifier,
18
+ get_pixel_coords,
19
+ )
20
+ from ltx_core.model.transformer.modality import Modality
21
+ from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
22
+ from ltx_trainer.timestep_samplers import TimestepSampler
23
+
24
+ # Default frames per second for videos that are missing FPS metadata
25
+ DEFAULT_FPS = 24
26
+
27
+ # VAE scale factors for LTX-2
28
+ VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
29
+
30
+
31
+ class TrainingStrategyConfigBase(BaseModel):
32
+ """Base configuration class for training strategies.
33
+ All strategy-specific configuration classes should inherit from this.
34
+ """
35
+
36
+ model_config = ConfigDict(extra="forbid")
37
+
38
+ name: Literal["text_to_video", "video_to_video"] = Field(
39
+ description="Unique name identifying the training strategy type"
40
+ )
41
+
42
+
43
+ @dataclass
44
+ class ModelInputs:
45
+ """Container for model inputs using the Modality-based interface."""
46
+
47
+ video: Modality
48
+ audio: Modality | None
49
+
50
+ # Training targets (for loss computation)
51
+ video_targets: Tensor
52
+ audio_targets: Tensor | None
53
+
54
+ # Masks for loss computation
55
+ video_loss_mask: Tensor # Boolean mask: True = compute loss for this token
56
+ audio_loss_mask: Tensor | None
57
+
58
+ # Metadata needed for loss computation in some strategies
59
+ ref_seq_len: int | None = None # For IC-LoRA: length of reference sequence
60
+
61
+
62
+ class TrainingStrategy(ABC):
63
+ """Abstract base class for training strategies.
64
+ Each strategy encapsulates the logic for a specific training mode,
65
+ handling input preparation and loss computation.
66
+ """
67
+
68
+ def __init__(self, config: TrainingStrategyConfigBase):
69
+ """Initialize strategy with configuration.
70
+ Args:
71
+ config: Strategy-specific configuration
72
+ """
73
+ self.config = config
74
+ self._video_patchifier = VideoLatentPatchifier(patch_size=1)
75
+ self._audio_patchifier = AudioPatchifier(patch_size=1)
76
+
77
+ @property
78
+ def requires_audio(self) -> bool:
79
+ """Whether this training strategy requires audio components.
80
+ Override this property in subclasses that support audio training.
81
+ The trainer uses this to determine whether to load audio VAE and vocoder.
82
+ Returns:
83
+ True if audio components should be loaded, False otherwise.
84
+ """
85
+ return False
86
+
87
+ @abstractmethod
88
+ def get_data_sources(self) -> list[str] | dict[str, str]:
89
+ """Get the required data sources for this training strategy.
90
+ Returns:
91
+ Either a list of data directory names (where output keys match directory names)
92
+ or a dictionary mapping data directory names to custom output keys for the dataset
93
+ """
94
+
95
+ @abstractmethod
96
+ def prepare_training_inputs(
97
+ self,
98
+ batch: dict[str, Any],
99
+ timestep_sampler: TimestepSampler,
100
+ ) -> ModelInputs:
101
+ """Prepare training inputs from a raw data batch.
102
+ Args:
103
+ batch: Raw batch data from the dataset. Contains:
104
+ - "latents": Video latent data
105
+ - "conditions": Text embeddings with keys:
106
+ - "video_prompt_embeds": Already processed by embedding connectors
107
+ - "audio_prompt_embeds": Already processed by embedding connectors
108
+ - "prompt_attention_mask": Attention mask
109
+ - Additional keys depending on strategy (e.g., "ref_latents" for IC-LoRA)
110
+ timestep_sampler: Sampler for generating timesteps and noise
111
+ Returns:
112
+ ModelInputs containing Modality objects and training targets
113
+ """
114
+
115
+ @abstractmethod
116
+ def compute_loss(
117
+ self,
118
+ video_pred: Tensor,
119
+ audio_pred: Tensor | None,
120
+ inputs: ModelInputs,
121
+ ) -> Tensor:
122
+ """Compute the training loss.
123
+ Args:
124
+ video_pred: Video prediction from the transformer model
125
+ audio_pred: Audio prediction from the transformer model (None for video-only)
126
+ inputs: The prepared model inputs containing targets and masks
127
+ Returns:
128
+ Scalar loss tensor
129
+ """
130
+
131
+ def get_checkpoint_metadata(self) -> dict[str, Any]:
132
+ """Get strategy-specific metadata to include in checkpoint files.
133
+ Override this method in subclasses to add custom metadata,
134
+ e.g. any parameters that a downstream inference pipeline may need.
135
+ Returns:
136
+ Dictionary of metadata key-value pairs (values must be JSON-serializable)
137
+ """
138
+ return {}
139
+
140
+ def _get_video_positions(
141
+ self,
142
+ num_frames: int,
143
+ height: int,
144
+ width: int,
145
+ batch_size: int,
146
+ fps: float,
147
+ device: torch.device,
148
+ dtype: torch.dtype,
149
+ ) -> Tensor:
150
+ """Generate video position embeddings using ltx_core's native implementation.
151
+ Args:
152
+ num_frames: Number of latent frames
153
+ height: Latent height
154
+ width: Latent width
155
+ batch_size: Batch size
156
+ fps: Frames per second
157
+ device: Target device
158
+ dtype: Target dtype
159
+ Returns:
160
+ Position tensor of shape [B, 3, seq_len, 2]
161
+ """
162
+ latent_coords = self._video_patchifier.get_patch_grid_bounds(
163
+ output_shape=VideoLatentShape(
164
+ frames=num_frames,
165
+ height=height,
166
+ width=width,
167
+ batch=batch_size,
168
+ channels=128, # Video latent channels
169
+ ),
170
+ device=device,
171
+ )
172
+
173
+ # Convert latent coords to pixel coords with causal fix
174
+ pixel_coords = get_pixel_coords(
175
+ latent_coords=latent_coords,
176
+ scale_factors=VIDEO_SCALE_FACTORS,
177
+ causal_fix=True,
178
+ ).to(dtype)
179
+
180
+ # Scale temporal dimension by 1/fps to get time in seconds
181
+ pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps
182
+
183
+ return pixel_coords
184
+
185
+ def _get_audio_positions(
186
+ self,
187
+ num_time_steps: int,
188
+ batch_size: int,
189
+ device: torch.device,
190
+ dtype: torch.dtype,
191
+ ) -> Tensor:
192
+ """Generate audio position embeddings using ltx_core's native implementation.
193
+ Args:
194
+ num_time_steps: Number of audio time steps (T, not T*mel_bins)
195
+ batch_size: Batch size
196
+ device: Target device
197
+ dtype: Target dtype
198
+ Returns:
199
+ Position tensor of shape [B, 1, num_time_steps, 2]
200
+ Note:
201
+ Audio latents should be in patchified format [B, T, C*F] = [B, T, 128]
202
+ where T is the number of time steps, C=8 channels, F=16 mel bins.
203
+ This matches the format produced by AudioPatchifier.patchify().
204
+ """
205
+ mel_bins = 16
206
+
207
+ latent_coords = self._audio_patchifier.get_patch_grid_bounds(
208
+ output_shape=AudioLatentShape(
209
+ frames=num_time_steps,
210
+ mel_bins=mel_bins,
211
+ batch=batch_size,
212
+ channels=8, # Audio latent channels
213
+ ),
214
+ device=device,
215
+ )
216
+
217
+ return latent_coords.to(dtype)
218
+
219
+ @staticmethod
220
+ def _create_per_token_timesteps(conditioning_mask: Tensor, sampled_sigma: Tensor) -> Tensor:
221
+ """Create per-token timesteps based on conditioning mask.
222
+ Args:
223
+ conditioning_mask: Boolean mask of shape (batch_size, sequence_length),
224
+ where True = conditioning token (timestep=0), False = target token (use sigma)
225
+ sampled_sigma: Sampled sigma values of shape (batch_size,) or (batch_size, 1, 1)
226
+ Returns:
227
+ Timesteps tensor of shape [batch_size, sequence_length]
228
+ """
229
+ # Expand to match conditioning mask shape [B, seq_len]
230
+ expanded_sigma = sampled_sigma.view(-1, 1).expand_as(conditioning_mask)
231
+
232
+ # Conditioning tokens get 0, target tokens get the sampled sigma
233
+ return torch.where(conditioning_mask, torch.zeros_like(expanded_sigma), expanded_sigma)
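A tiny worked example of the helper above (values are illustrative): conditioning tokens receive timestep 0, while target tokens keep the sampled sigma.

import torch

mask = torch.tensor([[True, True, False, False]])  # first two tokens are conditioning
sigma = torch.tensor([0.7])
# TrainingStrategy._create_per_token_timesteps(mask, sigma)
# -> tensor([[0.0000, 0.0000, 0.7000, 0.7000]])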
234
+
235
+ @staticmethod
236
+ def _create_first_frame_conditioning_mask(
237
+ batch_size: int,
238
+ sequence_length: int,
239
+ height: int,
240
+ width: int,
241
+ device: torch.device,
242
+ first_frame_conditioning_p: float = 0.0,
243
+ ) -> Tensor:
244
+ """Create conditioning mask for first frame conditioning.
245
+ Args:
246
+ batch_size: Batch size
247
+ sequence_length: Total sequence length
248
+ height: Latent height
249
+ width: Latent width
250
+ device: Target device
251
+ first_frame_conditioning_p: Probability of conditioning on the first frame
252
+ Returns:
253
+ Boolean mask where True indicates first frame tokens (if conditioning is enabled)
254
+ """
255
+ conditioning_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)
256
+
257
+ if first_frame_conditioning_p > 0 and random.random() < first_frame_conditioning_p:
258
+ first_frame_end_idx = height * width
259
+ if first_frame_end_idx < sequence_length:
260
+ conditioning_mask[:, :first_frame_end_idx] = True
261
+
262
+ return conditioning_mask
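To illustrate the extension point, a skeletal custom strategy follows (class name and the loss shown are hypothetical; a real strategy also needs a config class dispatched by get_training_strategy and, for a genuinely new name, a widened Literal on the base config):

from typing import Any

from torch import Tensor

from ltx_trainer.timestep_samplers import TimestepSampler
from ltx_trainer.training_strategies.base_strategy import ModelInputs, TrainingStrategy


class MyStrategy(TrainingStrategy):
    def get_data_sources(self) -> list[str] | dict[str, str]:
        return ["latents", "conditions"]

    def prepare_training_inputs(self, batch: dict[str, Any], timestep_sampler: TimestepSampler) -> ModelInputs:
        raise NotImplementedError  # build Modality inputs, noised latents, targets and loss masks here

    def compute_loss(self, video_pred: Tensor, audio_pred: Tensor | None, inputs: ModelInputs) -> Tensor:
        # One plausible masked MSE against the prepared targets.
        mse = (video_pred - inputs.video_targets) ** 2
        return mse[inputs.video_loss_mask].mean()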
packages/ltx-trainer/src/ltx_trainer/training_strategies/text_to_video.py ADDED
@@ -0,0 +1,291 @@
1
+ """Text-to-video training strategy.
2
+ This strategy implements standard text-to-video generation training where:
3
+ - Only target latents are used (no reference videos)
4
+ - Standard noise application and loss computation
5
+ - Supports first frame conditioning
6
+ - Optionally supports joint audio-video training
7
+ """
8
+
9
+ from typing import Any, Literal
10
+
11
+ import torch
12
+ from pydantic import Field
13
+ from torch import Tensor
14
+
15
+ from ltx_core.model.transformer.modality import Modality
16
+ from ltx_trainer import logger
17
+ from ltx_trainer.timestep_samplers import TimestepSampler
18
+ from ltx_trainer.training_strategies.base_strategy import (
19
+ DEFAULT_FPS,
20
+ ModelInputs,
21
+ TrainingStrategy,
22
+ TrainingStrategyConfigBase,
23
+ )
24
+
25
+
26
+ class TextToVideoConfig(TrainingStrategyConfigBase):
27
+ """Configuration for text-to-video training strategy."""
28
+
29
+ name: Literal["text_to_video"] = "text_to_video"
30
+
31
+ first_frame_conditioning_p: float = Field(
32
+ default=0.1,
33
+ description="Probability of conditioning on the first frame during training",
34
+ ge=0.0,
35
+ le=1.0,
36
+ )
37
+
38
+ with_audio: bool = Field(
39
+ default=False,
40
+ description="Whether to include audio in training (joint audio-video generation)",
41
+ )
42
+
43
+ audio_latents_dir: str = Field(
44
+ default="audio_latents",
45
+ description="Directory name for audio latents when with_audio is True",
46
+ )
47
+
48
+
49
+ class TextToVideoStrategy(TrainingStrategy):
50
+ """Text-to-video training strategy.
51
+ This strategy implements standard text-to-video generation training where:
52
+ - Only target latents are used (no reference videos)
53
+ - Standard noise application and loss computation
54
+ - Supports first frame conditioning
55
+ - Optionally supports joint audio-video training when with_audio=True
56
+ """
57
+
58
+ config: TextToVideoConfig
59
+
60
+ def __init__(self, config: TextToVideoConfig):
61
+ """Initialize strategy with configuration.
62
+ Args:
63
+ config: Text-to-video configuration
64
+ """
65
+ super().__init__(config)
66
+
67
+ @property
68
+ def requires_audio(self) -> bool:
69
+ """Whether this training strategy requires audio components."""
70
+ return self.config.with_audio
71
+
72
+ def get_data_sources(self) -> list[str] | dict[str, str]:
73
+ """
74
+ Text-to-video training requires latents and text conditions.
75
+ When with_audio is True, also requires audio latents.
76
+ """
77
+ sources = {
78
+ "latents": "latents",
79
+ "conditions": "conditions",
80
+ }
81
+
82
+ if self.config.with_audio:
83
+ sources[self.config.audio_latents_dir] = "audio_latents"
84
+
85
+ return sources
86
+
87
+ def prepare_training_inputs(
88
+ self,
89
+ batch: dict[str, Any],
90
+ timestep_sampler: TimestepSampler,
91
+ ) -> ModelInputs:
92
+ """Prepare inputs for text-to-video training."""
93
+ # Get pre-encoded latents - dataset provides uniform non-patchified format [B, C, F, H, W]
94
+ latents = batch["latents"]
95
+ video_latents = latents["latents"]
96
+
97
+ # Get video dimensions (assume same for all batch elements)
98
+ num_frames = latents["num_frames"][0].item()
99
+ height = latents["height"][0].item()
100
+ width = latents["width"][0].item()
101
+
102
+ # Patchify latents: [B, C, F, H, W] -> [B, seq_len, C]
103
+ video_latents = self._video_patchifier.patchify(video_latents)
104
+
105
+ # Handle FPS with backward compatibility
106
+ fps = latents.get("fps", None)
107
+ if fps is not None and not torch.all(fps == fps[0]):
108
+ logger.warning(
109
+ f"Different FPS values found in the batch. Found: {fps.tolist()}, using the first one: {fps[0].item()}"
110
+ )
111
+ fps = fps[0].item() if fps is not None else DEFAULT_FPS
112
+
113
+ # Get text embeddings (already processed by embedding connectors in trainer)
114
+ conditions = batch["conditions"]
115
+ video_prompt_embeds = conditions["video_prompt_embeds"]
116
+ audio_prompt_embeds = conditions["audio_prompt_embeds"]
117
+ prompt_attention_mask = conditions["prompt_attention_mask"]
118
+
119
+ batch_size = video_latents.shape[0]
120
+ video_seq_len = video_latents.shape[1]
121
+ device = video_latents.device
122
+ dtype = video_latents.dtype
123
+
124
+ # Create conditioning mask (first frame conditioning)
125
+ video_conditioning_mask = self._create_first_frame_conditioning_mask(
126
+ batch_size=batch_size,
127
+ sequence_length=video_seq_len,
128
+ height=height,
129
+ width=width,
130
+ device=device,
131
+ first_frame_conditioning_p=self.config.first_frame_conditioning_p,
132
+ )
133
+
134
+ # Sample noise and sigmas
135
+ sigmas = timestep_sampler.sample_for(video_latents)
136
+ video_noise = torch.randn_like(video_latents)
137
+
138
+ # Apply noise: noisy = (1 - sigma) * clean + sigma * noise
139
+ sigmas_expanded = sigmas.view(-1, 1, 1)
140
+ noisy_video = (1 - sigmas_expanded) * video_latents + sigmas_expanded * video_noise
141
+
142
+ # For conditioning tokens, use clean latents
143
+ conditioning_mask_expanded = video_conditioning_mask.unsqueeze(-1)
144
+ noisy_video = torch.where(conditioning_mask_expanded, video_latents, noisy_video)
145
+
146
+ # Compute video targets (velocity prediction)
147
+ video_targets = video_noise - video_latents
148
+
149
+ # Create per-token timesteps
150
+ video_timesteps = self._create_per_token_timesteps(video_conditioning_mask, sigmas.squeeze())
151
+
152
+ # Generate video positions using ltx_core's native implementation
153
+ video_positions = self._get_video_positions(
154
+ num_frames=num_frames,
155
+ height=height,
156
+ width=width,
157
+ batch_size=batch_size,
158
+ fps=fps,
159
+ device=device,
160
+ dtype=dtype,
161
+ )
162
+
163
+ # Create video Modality
164
+ video_modality = Modality(
165
+ enabled=True,
166
+ sigma=sigmas,
167
+ latent=noisy_video,
168
+ timesteps=video_timesteps,
169
+ positions=video_positions,
170
+ context=video_prompt_embeds,
171
+ context_mask=prompt_attention_mask,
172
+ )
173
+
174
+ # Video loss mask: True for tokens we want to compute loss on (non-conditioning tokens)
175
+ video_loss_mask = ~video_conditioning_mask
176
+
177
+ # Handle audio if enabled
178
+ audio_modality = None
179
+ audio_targets = None
180
+ audio_loss_mask = None
181
+
182
+ if self.config.with_audio:
183
+ audio_modality, audio_targets, audio_loss_mask = self._prepare_audio_inputs(
184
+ batch=batch,
185
+ sigmas=sigmas,
186
+ audio_prompt_embeds=audio_prompt_embeds,
187
+ prompt_attention_mask=prompt_attention_mask,
188
+ batch_size=batch_size,
189
+ device=device,
190
+ dtype=dtype,
191
+ )
192
+
193
+ return ModelInputs(
194
+ video=video_modality,
195
+ audio=audio_modality,
196
+ video_targets=video_targets,
197
+ audio_targets=audio_targets,
198
+ video_loss_mask=video_loss_mask,
199
+ audio_loss_mask=audio_loss_mask,
200
+ )
201
+
202
+ def _prepare_audio_inputs(
203
+ self,
204
+ batch: dict[str, Any],
205
+ sigmas: Tensor,
206
+ audio_prompt_embeds: Tensor,
207
+ prompt_attention_mask: Tensor,
208
+ batch_size: int,
209
+ device: torch.device,
210
+ dtype: torch.dtype,
211
+ ) -> tuple[Modality, Tensor, Tensor]:
212
+ """Prepare audio inputs for joint audio-video training.
213
+ Args:
214
+ batch: Raw batch data containing audio_latents
215
+ sigmas: Sampled sigma values (same as video)
216
+ audio_prompt_embeds: Audio context embeddings
217
+ prompt_attention_mask: Attention mask for context
218
+ batch_size: Batch size
219
+ device: Target device
220
+ dtype: Target dtype
221
+ Returns:
222
+ Tuple of (audio_modality, audio_targets, audio_loss_mask)
223
+ """
224
+ # Get audio latents - dataset provides uniform non-patchified format [B, C, T, F]
225
+ audio_data = batch["audio_latents"]
226
+ audio_latents = audio_data["latents"]
227
+
228
+ # Patchify audio latents: [B, C, T, F] -> [B, T, C*F]
229
+ audio_latents = self._audio_patchifier.patchify(audio_latents)
230
+
231
+ audio_seq_len = audio_latents.shape[1]
232
+
233
+ # Sample audio noise
234
+ audio_noise = torch.randn_like(audio_latents)
235
+
236
+ # Apply noise to audio (same sigma as video)
237
+ sigmas_expanded = sigmas.view(-1, 1, 1)
238
+ noisy_audio = (1 - sigmas_expanded) * audio_latents + sigmas_expanded * audio_noise
239
+
240
+ # Compute audio targets
241
+ audio_targets = audio_noise - audio_latents
242
+
243
+ # Audio timesteps: all tokens use the sampled sigma (no conditioning mask)
244
+ audio_timesteps = sigmas.view(-1, 1).expand(-1, audio_seq_len)
245
+
246
+ # Generate audio positions
247
+ audio_positions = self._get_audio_positions(
248
+ num_time_steps=audio_seq_len,
249
+ batch_size=batch_size,
250
+ device=device,
251
+ dtype=dtype,
252
+ )
253
+
254
+ # Create audio Modality
255
+ audio_modality = Modality(
256
+ enabled=True,
257
+ latent=noisy_audio,
258
+ sigma=sigmas,
259
+ timesteps=audio_timesteps,
260
+ positions=audio_positions,
261
+ context=audio_prompt_embeds,
262
+ context_mask=prompt_attention_mask,
263
+ )
264
+
265
+ # Audio loss mask: all tokens contribute to loss (no conditioning)
266
+ audio_loss_mask = torch.ones(batch_size, audio_seq_len, dtype=torch.bool, device=device)
267
+
268
+ return audio_modality, audio_targets, audio_loss_mask
269
+
270
+ def compute_loss(
271
+ self,
272
+ video_pred: Tensor,
273
+ audio_pred: Tensor | None,
274
+ inputs: ModelInputs,
275
+ ) -> Tensor:
276
+ """Compute masked MSE loss for video and optionally audio."""
277
+ # Video loss
278
+ video_loss = (video_pred - inputs.video_targets).pow(2)
279
+ video_loss_mask = inputs.video_loss_mask.unsqueeze(-1).float()
280
+ video_loss = video_loss.mul(video_loss_mask).div(video_loss_mask.mean())
281
+ video_loss = video_loss.mean()
282
+
283
+ # If no audio, return video loss only
284
+ if not self.config.with_audio or audio_pred is None or inputs.audio_targets is None:
285
+ return video_loss
286
+
287
+ # Audio loss (no conditioning mask)
288
+ audio_loss = (audio_pred - inputs.audio_targets).pow(2).mean()
289
+
290
+ # Combined loss
291
+ return video_loss + audio_loss
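A short sketch of the noising and loss arithmetic used by `TextToVideoStrategy`, with invented shapes purely for illustration (the real tensors come from the patchifier and dataset):

```python
import torch

# Illustrative shapes only: batch of 2, 4 tokens, 8 latent channels.
latents = torch.randn(2, 4, 8)
noise = torch.randn_like(latents)
sigmas = torch.tensor([0.8, 0.2]).view(-1, 1, 1)

# Noising as in prepare_training_inputs: interpolate clean latents toward noise.
noisy = (1 - sigmas) * latents + sigmas * noise

# Velocity target the transformer is trained to predict.
targets = noise - latents

# Masked MSE as in compute_loss: the first token of each sample is treated as a
# conditioning token and excluded from the loss.
loss_mask = torch.tensor([[False, True, True, True],
                          [False, True, True, True]])
pred = torch.randn_like(latents)  # stand-in for the model output
mask = loss_mask.unsqueeze(-1).float()
loss = (pred - targets).pow(2).mul(mask).div(mask.mean()).mean()
print(loss.item())
```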
packages/ltx-trainer/src/ltx_trainer/training_strategies/video_to_video.py ADDED
@@ -0,0 +1,303 @@
1
+ """Video-to-video training strategy for IC-LoRA.
2
+ This strategy implements training with reference video conditioning where:
3
+ - Reference latents (clean) are concatenated with target latents (noised)
4
+ - Video coordinates handle both reference and target sequences
5
+ - Loss is computed only on the target portion
6
+ """
7
+
8
+ from typing import Any, Literal
9
+
10
+ import torch
11
+ from pydantic import Field
12
+ from torch import Tensor
13
+
14
+ from ltx_core.model.transformer.modality import Modality
15
+ from ltx_trainer import logger
16
+ from ltx_trainer.timestep_samplers import TimestepSampler
17
+ from ltx_trainer.training_strategies.base_strategy import (
18
+ DEFAULT_FPS,
19
+ ModelInputs,
20
+ TrainingStrategy,
21
+ TrainingStrategyConfigBase,
22
+ )
23
+
24
+
25
+ class VideoToVideoConfig(TrainingStrategyConfigBase):
26
+ """Configuration for video-to-video (IC-LoRA) training strategy."""
27
+
28
+ name: Literal["video_to_video"] = "video_to_video"
29
+
30
+ first_frame_conditioning_p: float = Field(
31
+ default=0.1,
32
+ description="Probability of conditioning on the first frame during training",
33
+ ge=0.0,
34
+ le=1.0,
35
+ )
36
+
37
+ reference_latents_dir: str = Field(
38
+ default="reference_latents",
39
+ description="Directory name for latents of reference videos",
40
+ )
41
+
42
+
43
+ class VideoToVideoStrategy(TrainingStrategy):
44
+ """Video-to-video training strategy for IC-LoRA.
45
+ This strategy implements training with reference video conditioning where:
46
+ - Reference latents (clean) are concatenated with target latents (noised)
47
+ - Video coordinates handle both reference and target sequences
48
+ - Loss is computed only on the target portion
49
+ Attributes:
50
+ reference_downscale_factor: The inferred downscale factor of reference videos.
51
+ This is computed from the first batch and cached for metadata export.
52
+ """
53
+
54
+ config: VideoToVideoConfig
55
+ reference_downscale_factor: int | None
56
+
57
+ def __init__(self, config: VideoToVideoConfig):
58
+ """Initialize strategy with configuration.
59
+ Args:
60
+ config: Video-to-video configuration
61
+ """
62
+ super().__init__(config)
63
+ self.reference_downscale_factor = None # Will be inferred from first batch
64
+
65
+ def get_data_sources(self) -> dict[str, str]:
66
+ """IC-LoRA training requires latents, conditions, and reference latents."""
67
+ return {
68
+ "latents": "latents",
69
+ "conditions": "conditions",
70
+ self.config.reference_latents_dir: "ref_latents",
71
+ }
72
+
73
+ def prepare_training_inputs( # noqa: PLR0915
74
+ self,
75
+ batch: dict[str, Any],
76
+ timestep_sampler: TimestepSampler,
77
+ ) -> ModelInputs:
78
+ """Prepare inputs for IC-LoRA training with reference videos."""
79
+ # Get pre-encoded latents - dataset provides uniform non-patchified format [B, C, F, H, W]
80
+ latents = batch["latents"]
81
+ target_latents = latents["latents"]
82
+ ref_latents = batch["ref_latents"]["latents"]
83
+
84
+ # Get dimensions
85
+ num_frames = latents["num_frames"][0].item()
86
+ height = latents["height"][0].item()
87
+ width = latents["width"][0].item()
88
+
89
+ ref_latents_info = batch["ref_latents"]
90
+ ref_frames = ref_latents_info["num_frames"][0].item()
91
+ ref_height = ref_latents_info["height"][0].item()
92
+ ref_width = ref_latents_info["width"][0].item()
93
+
94
+ # Infer reference downscale factor from dimension ratios
95
+ # This allows training with downscaled reference videos for efficiency
96
+ reference_downscale_factor = self._infer_reference_downscale_factor(
97
+ target_height=height,
98
+ target_width=width,
99
+ ref_height=ref_height,
100
+ ref_width=ref_width,
101
+ )
102
+
103
+ # Cache the scale factor for metadata export (only on first batch)
104
+ if self.reference_downscale_factor is None:
105
+ self.reference_downscale_factor = reference_downscale_factor
106
+ elif self.reference_downscale_factor != reference_downscale_factor:
107
+ raise ValueError(
108
+ f"Inconsistent reference downscale factor across batches. "
109
+ f"First batch had factor={self.reference_downscale_factor}, "
110
+ f"but current batch has factor={reference_downscale_factor}. "
111
+ f"All training samples must use the same reference/target resolution ratio."
112
+ )
113
+
114
+ # Patchify latents: [B, C, F, H, W] -> [B, seq_len, C]
115
+ target_latents = self._video_patchifier.patchify(target_latents)
116
+ ref_latents = self._video_patchifier.patchify(ref_latents)
117
+
118
+ # Handle FPS
119
+ fps = latents.get("fps", None)
120
+ if fps is not None and not torch.all(fps == fps[0]):
121
+ logger.warning(
122
+ f"Different FPS values found in the batch. Found: {fps.tolist()}, using the first one: {fps[0].item()}"
123
+ )
124
+ fps = fps[0].item() if fps is not None else DEFAULT_FPS
125
+
126
+ # Get text embeddings (already processed by embedding connectors in trainer)
127
+ # Video-to-video uses only video embeddings
128
+ conditions = batch["conditions"]
129
+ prompt_embeds = conditions["video_prompt_embeds"]
130
+ prompt_attention_mask = conditions["prompt_attention_mask"]
131
+
132
+ batch_size = target_latents.shape[0]
133
+ ref_seq_len = ref_latents.shape[1]
134
+ target_seq_len = target_latents.shape[1]
135
+ device = target_latents.device
136
+ dtype = target_latents.dtype
137
+
138
+ # Create conditioning mask
139
+ # Reference tokens are always conditioning (timestep=0)
140
+ ref_conditioning_mask = torch.ones(batch_size, ref_seq_len, dtype=torch.bool, device=device)
141
+
142
+ # Target tokens: check for first frame conditioning
143
+ target_conditioning_mask = self._create_first_frame_conditioning_mask(
144
+ batch_size=batch_size,
145
+ sequence_length=target_seq_len,
146
+ height=height,
147
+ width=width,
148
+ device=device,
149
+ first_frame_conditioning_p=self.config.first_frame_conditioning_p,
150
+ )
151
+
152
+ # Combined conditioning mask
153
+ conditioning_mask = torch.cat([ref_conditioning_mask, target_conditioning_mask], dim=1)
154
+
155
+ # Sample noise and sigmas for target
156
+ sigmas = timestep_sampler.sample_for(target_latents)
157
+ noise = torch.randn_like(target_latents)
158
+ sigmas_expanded = sigmas.view(-1, 1, 1)
159
+
160
+ # Apply noise to target
161
+ noisy_target = (1 - sigmas_expanded) * target_latents + sigmas_expanded * noise
162
+
163
+ # For first frame conditioning in target, use clean latents
164
+ target_conditioning_mask_expanded = target_conditioning_mask.unsqueeze(-1)
165
+ noisy_target = torch.where(target_conditioning_mask_expanded, target_latents, noisy_target)
166
+
167
+ # Targets for loss computation
168
+ targets = noise - target_latents
169
+
170
+ # Concatenate reference (clean) and target (noisy)
171
+ combined_latents = torch.cat([ref_latents, noisy_target], dim=1)
172
+
173
+ # Create per-token timesteps
174
+ timesteps = self._create_per_token_timesteps(conditioning_mask, sigmas.squeeze())
175
+
176
+ # Generate positions for reference and target separately, then concatenate
177
+ ref_positions = self._get_video_positions(
178
+ num_frames=ref_frames,
179
+ height=ref_height,
180
+ width=ref_width,
181
+ batch_size=batch_size,
182
+ fps=fps,
183
+ device=device,
184
+ dtype=dtype,
185
+ )
186
+
187
+ # Scale reference positions to match target coordinate space
188
+ # This maps ref positions from (0, ref_H, ref_W) to (0, target_H, target_W)
189
+ # Position tensor shape: [B, 3, seq_len, 2] where dim 1 is (time, height, width)
190
+ if reference_downscale_factor != 1:
191
+ ref_positions = ref_positions.clone()
192
+ ref_positions[:, 1, ...] *= reference_downscale_factor # height axis
193
+ ref_positions[:, 2, ...] *= reference_downscale_factor # width axis
194
+ # Time axis (index 0) remains unchanged
195
+
196
+ target_positions = self._get_video_positions(
197
+ num_frames=num_frames,
198
+ height=height,
199
+ width=width,
200
+ batch_size=batch_size,
201
+ fps=fps,
202
+ device=device,
203
+ dtype=dtype,
204
+ )
205
+
206
+ # Concatenate positions along sequence dimension
207
+ positions = torch.cat([ref_positions, target_positions], dim=2)
208
+
209
+ # Create video Modality
210
+ video_modality = Modality(
211
+ enabled=True,
212
+ latent=combined_latents,
213
+ sigma=sigmas,
214
+ timesteps=timesteps,
215
+ positions=positions,
216
+ context=prompt_embeds,
217
+ context_mask=prompt_attention_mask,
218
+ )
219
+
220
+ # Loss mask: only compute loss on non-conditioning target tokens
221
+ # Reference tokens: all False (no loss)
222
+ # Target tokens: True where not conditioning
223
+ ref_loss_mask = torch.zeros(batch_size, ref_seq_len, dtype=torch.bool, device=device)
224
+ target_loss_mask = ~target_conditioning_mask
225
+ video_loss_mask = torch.cat([ref_loss_mask, target_loss_mask], dim=1)
226
+
227
+ return ModelInputs(
228
+ video=video_modality,
229
+ audio=None,
230
+ video_targets=targets,
231
+ audio_targets=None,
232
+ video_loss_mask=video_loss_mask,
233
+ audio_loss_mask=None,
234
+ ref_seq_len=ref_seq_len,
235
+ )
236
+
237
+ def compute_loss(
238
+ self,
239
+ video_pred: Tensor,
240
+ _audio_pred: Tensor | None,
241
+ inputs: ModelInputs,
242
+ ) -> Tensor:
243
+ """Compute masked loss only on target portion."""
244
+ # Extract target portion of prediction
245
+ ref_seq_len = inputs.ref_seq_len
246
+ target_pred = video_pred[:, ref_seq_len:, :]
247
+
248
+ # Get target portion of loss mask
249
+ target_loss_mask = inputs.video_loss_mask[:, ref_seq_len:]
250
+
251
+ # Compute loss
252
+ loss = (target_pred - inputs.video_targets).pow(2)
253
+
254
+ # Apply loss mask
255
+ loss_mask = target_loss_mask.unsqueeze(-1).float()
256
+ loss = loss.mul(loss_mask).div(loss_mask.mean())
257
+
258
+ return loss.mean()
259
+
260
+ def get_checkpoint_metadata(self) -> dict[str, Any]:
261
+ """Get metadata for checkpoint files."""
262
+ metadata: dict[str, Any] = {}
263
+ # Always include reference_downscale_factor for IC-LoRAs so inference
264
+ # pipelines know the expected scale factor for reference videos.
265
+ if self.reference_downscale_factor is not None:
266
+ metadata["reference_downscale_factor"] = self.reference_downscale_factor
267
+ return metadata
268
+
269
+ @staticmethod
270
+ def _infer_reference_downscale_factor(
271
+ target_height: int,
272
+ target_width: int,
273
+ ref_height: int,
274
+ ref_width: int,
275
+ ) -> int:
276
+ """Infer the reference downscale factor from target and reference dimensions."""
277
+ # If dimensions match, no scaling needed
278
+ if target_height == ref_height and target_width == ref_width:
279
+ return 1
280
+
281
+ # Calculate scale factors for each dimension
282
+ if target_height % ref_height != 0 or target_width % ref_width != 0:
283
+ raise ValueError(
284
+ f"Target dimensions ({target_height}x{target_width}) must be exact multiples "
285
+ f"of reference dimensions ({ref_height}x{ref_width})"
286
+ )
287
+
288
+ scale_h = target_height // ref_height
289
+ scale_w = target_width // ref_width
290
+
291
+ if scale_h != scale_w:
292
+ raise ValueError(
293
+ f"Reference scale must be uniform. Got height scale {scale_h} and width scale {scale_w}. "
294
+ f"Target: {target_height}x{target_width}, Reference: {ref_height}x{ref_width}"
295
+ )
296
+
297
+ if scale_h < 1:
298
+ raise ValueError(
299
+ f"Reference dimensions ({ref_height}x{ref_width}) cannot be larger than "
300
+ f"target dimensions ({target_height}x{target_width})"
301
+ )
302
+
303
+ return scale_h
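A toy sketch of the core IC-LoRA layout implemented above: clean reference tokens are concatenated in front of the noised target tokens, and only the target slice contributes to the loss. Shapes are invented for illustration:

```python
import torch

# Toy shapes only: 6 reference tokens followed by 10 target tokens, 8 channels.
batch_size, ref_seq_len, target_seq_len, channels = 2, 6, 10, 8
ref_latents = torch.randn(batch_size, ref_seq_len, channels)      # kept clean
noisy_target = torch.randn(batch_size, target_seq_len, channels)  # noised

# Reference and target are concatenated along the sequence dimension,
# as in prepare_training_inputs.
combined = torch.cat([ref_latents, noisy_target], dim=1)          # [B, 16, C]

# Loss is computed only on the target slice, as in compute_loss.
video_pred = torch.randn_like(combined)                           # model output stand-in
target_pred = video_pred[:, ref_seq_len:, :]
targets = torch.randn(batch_size, target_seq_len, channels)
loss = (target_pred - targets).pow(2).mean()
```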
packages/ltx-trainer/src/ltx_trainer/utils.py ADDED
@@ -0,0 +1,88 @@
1
+ import io
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import ExifTags, Image, ImageCms, ImageOps
7
+ from PIL.Image import Image as PilImage
8
+
9
+
10
+ def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage:
11
+ """
12
+ Opens an image file, applies rotation (if it's set in metadata) and converts it
13
+ to the sRGB color space respecting the original image color space .
14
+ Args:
15
+ image_path: Path to the image file
16
+ Returns:
17
+ PIL Image in sRGB color space
18
+ """
19
+ exif_colorspace_srgb = 1
20
+
21
+ with Image.open(image_path) as img_raw:
22
+ img = ImageOps.exif_transpose(img_raw)
23
+
24
+ input_icc_profile = img.info.get("icc_profile")
25
+
26
+ # Try to convert to sRGB if the image has ICC profile metadata
27
+ srgb_profile = ImageCms.createProfile(colorSpace="sRGB")
28
+ if input_icc_profile is not None:
29
+ input_profile = ImageCms.ImageCmsProfile(io.BytesIO(input_icc_profile))
30
+ srgb_img = ImageCms.profileToProfile(img, input_profile, srgb_profile, outputMode="RGB")
31
+ else:
32
+ # Try fall back to checking EXIF
33
+ exif_data = img.getexif()
34
+ if exif_data is not None:
35
+ # Assume sRGB if no ICC profile and EXIF has no ColorSpace tag
36
+ color_space_value = exif_data.get(ExifTags.Base.ColorSpace.value)
37
+ if color_space_value is not None and color_space_value != exif_colorspace_srgb:
38
+ raise ValueError(
39
+ "Image has colorspace tag in EXIF but it isn't set to sRGB,"
40
+ " conversion is not supported."
41
+ f" EXIF ColorSpace tag value is {color_space_value}",
42
+ )
43
+
44
+ srgb_img = img.convert("RGB")
45
+
46
+ # Set sRGB profile in metadata since now the image is assumed to be in sRGB.
47
+ srgb_profile_data = ImageCms.ImageCmsProfile(srgb_profile).tobytes()
48
+ srgb_img.info["icc_profile"] = srgb_profile_data
49
+
50
+ return srgb_img
51
+
52
+
53
+ def save_image(image_tensor: torch.Tensor, output_path: Path | str) -> None:
54
+ """Save an image tensor to a file.
55
+ Args:
56
+ image_tensor: Image tensor of shape [C, H, W] or [C, 1, H, W] in range [0, 1] or [0, 255].
57
+ C must be 3 (RGB).
58
+ output_path: Path to save the image (any PIL-supported format, e.g., .png or .jpg)
59
+ """
60
+ output_path = Path(output_path)
61
+ output_path.parent.mkdir(parents=True, exist_ok=True)
62
+
63
+ # Handle [C, 1, H, W] format (single frame from video tensor)
64
+ if image_tensor.ndim == 4:
65
+ # Squeeze frame dimension: [C, 1, H, W] -> [C, H, W]
66
+ if image_tensor.shape[1] == 1:
67
+ image_tensor = image_tensor.squeeze(1)
68
+ else:
69
+ raise ValueError(f"Expected single-frame tensor with shape [C, 1, H, W], got shape {image_tensor.shape}")
70
+
71
+ if image_tensor.ndim != 3:
72
+ raise ValueError(f"Expected 3D tensor [C, H, W], got {image_tensor.ndim}D tensor")
73
+
74
+ if image_tensor.shape[0] != 3:
75
+ raise ValueError(f"Expected 3 channels (RGB), got {image_tensor.shape[0]} channels")
76
+
77
+ # Normalize to [0, 255] uint8
78
+ if torch.is_floating_point(image_tensor) and image_tensor.max() <= 1.0:
79
+ image_tensor = image_tensor * 255
80
+
81
+ # Clamp to valid uint8 range to prevent overflow
82
+ image_tensor = image_tensor.clamp(0, 255)
83
+
84
+ # [C, H, W] -> [H, W, C]
85
+ image_np: np.ndarray = image_tensor.permute(1, 2, 0).to(torch.uint8).cpu().numpy()
86
+
87
+ # Save using PIL
88
+ Image.fromarray(image_np).save(output_path)
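Hypothetical usage of the two helpers above; the file paths are placeholders chosen for illustration:

```python
import torch
from ltx_trainer.utils import open_image_as_srgb, save_image

# Load an image, honoring EXIF rotation and converting to sRGB.
img = open_image_as_srgb("frame_0001.png")  # placeholder path

# save_image expects a [3, H, W] (or [3, 1, H, W]) tensor in [0, 1] or [0, 255].
frame = torch.rand(3, 256, 256)             # e.g. a decoded RGB frame
save_image(frame, "outputs/preview.png")    # placeholder output path
```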
packages/ltx-trainer/templates/model_card.md ADDED
@@ -0,0 +1,59 @@
1
+ ---
2
+ tags:
3
+ - ltx-2
4
+ - ltx-video
5
+ - text-to-video
6
+ - audio-video
7
+ pinned: true
8
+ language:
9
+ - en
10
+ license: other
11
+ pipeline_tag: text-to-video
12
+ library_name: diffusers
13
+ ---
14
+
15
+ # {model_name}
16
+
17
+ This is a fine-tuned version of [`{base_model}`]({base_model_link}) trained on custom data.
18
+
19
+ ## Model Details
20
+
21
+ - **Base Model:** [`{base_model}`]({base_model_link})
22
+ - **Training Type:** {training_type}
23
+ - **Training Steps:** {training_steps}
24
+ - **Learning Rate:** {learning_rate}
25
+ - **Batch Size:** {batch_size}
26
+
27
+ ## Sample Outputs
28
+
29
+ | | | | |
30
+ |:---:|:---:|:---:|:---:|
31
+ {sample_grid}
32
+
33
+ ## Usage
34
+
35
+ This model is designed to be used with the LTX-2 (Lightricks Audio-Video) pipeline.
36
+
37
+ ### 🔌 Using Trained LoRAs in ComfyUI
38
+
39
+ In order to use the trained LoRA in ComfyUI, follow these steps:
40
+
41
+ 1. Copy your trained LoRA checkpoint (`.safetensors` file) to the `models/loras` folder in your ComfyUI installation.
42
+ 2. In your ComfyUI workflow:
43
+ - Add the "Load LoRA" node to choose your LoRA file
44
+ - Connect it to the "Load Checkpoint" node to apply the LoRA to the base model
45
+
46
+ You can find reference Text-to-Video (T2V) and Image-to-Video (I2V) workflows in the
47
+ official [LTX-2 repository](https://github.com/Lightricks/LTX-2).
48
+
49
+ ### Example Prompts
50
+
51
+ {validation_prompts}
52
+
53
+
54
+ This model inherits the license of the base model ([`{base_model}`]({base_model_link})).
55
+
56
+ ## Acknowledgments
57
+
58
+ - Base model: [Lightricks](https://huggingface.co/Lightricks/LTX-2)
59
+ - Trainer: [LTX-2](https://github.com/Lightricks/LTX-2)
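One way the placeholders in this template could be filled is with plain `str.format`; the snippet below is a sketch under that assumption (the trainer's actual rendering code is not part of this diff, and all field values are placeholders):

```python
from pathlib import Path

# Hypothetical rendering of the model-card template above.
template = Path("packages/ltx-trainer/templates/model_card.md").read_text()
card = template.format(
    model_name="my-ltx2-lora",
    base_model="Lightricks/LTX-2",
    base_model_link="https://huggingface.co/Lightricks/LTX-2",
    training_type="LoRA",
    training_steps=2000,
    learning_rate=1e-4,
    batch_size=1,
    sample_grid="| sample | sample | sample | sample |",
    validation_prompts="- A close-up of ocean waves at sunset",
)
Path("README.md").write_text(card)
```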