dagloop5 committed on
Commit 3a6dd90 · verified · 1 Parent(s): 9cedc73

Create app.py

Files changed (1)
  1. app.py +1018 -0
app.py ADDED
@@ -0,0 +1,1018 @@
+ # =============================================================================
+ # Installation and Setup
+ # =============================================================================
+ import os
+ import subprocess
+ import sys
+
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
+ os.environ["TORCHDYNAMO_DISABLE"] = "1"
+
+ subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
+
+ LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
+ LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
+ LTX_COMMIT = "ae855f8538843825f9015a419cf4ba5edaf5eec2"
+
+ if not os.path.exists(LTX_REPO_DIR):
+     print(f"Cloning {LTX_REPO_URL}...")
+     subprocess.run(["git", "clone", LTX_REPO_URL, LTX_REPO_DIR], check=True)
+     subprocess.run(["git", "checkout", LTX_COMMIT], cwd=LTX_REPO_DIR, check=True)
+
+ print("Installing ltx-core and ltx-pipelines from cloned repo...")
+ subprocess.run(
+     [sys.executable, "-m", "pip", "install", "--force-reinstall", "--no-deps", "-e",
+      os.path.join(LTX_REPO_DIR, "packages", "ltx-core"),
+      "-e", os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines")],
+     check=True,
+ )
+
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
+ sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
+
+ # =============================================================================
+ # Imports
+ # =============================================================================
+ import logging
+ import random
+ import tempfile
+ from pathlib import Path
+ import gc
+ import hashlib
+
+ import torch
+ torch._dynamo.config.suppress_errors = True
+ torch._dynamo.config.disable = True
+
+ import spaces
+ import gradio as gr
+ import numpy as np
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from safetensors.torch import load_file, save_file
+
+ from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
+ from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
+ from ltx_core.model.video_vae import decode_video as vae_decode_video
+ from ltx_core.model.upsampler import upsample_video
+ from ltx_core.quantization import QuantizationPolicy
+ from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
+ from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
+ from ltx_core.components.noisers import GaussianNoiser
+ from ltx_core.components.diffusion_steps import Res2sDiffusionStep
+ from ltx_core.components.schedulers import LTX2Scheduler
+ from ltx_core.types import Audio, LatentState, VideoPixelShape, AudioLatentShape
+ from ltx_core.tools import VideoLatentShape
+
+ from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
+ from ltx_pipelines.utils.args import ImageConditioningInput
+ from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS, STAGE_2_DISTILLED_SIGMA_VALUES
+ from ltx_pipelines.utils.media_io import encode_video
+ from ltx_pipelines.utils.helpers import (
+     assert_resolution,
+     cleanup_memory,
+     combined_image_conditionings,
+     encode_prompts,
+     multi_modal_guider_denoising_func,
+     simple_denoising_func,
+     denoise_audio_video,
+ )
+
+ from ltx_pipelines.utils import res2s_audio_video_denoising_loop
+
+ # Patch xformers
+ try:
+     from ltx_core.model.transformer import attention as _attn_mod
+     from xformers.ops import memory_efficient_attention as _mea
+     _attn_mod.memory_efficient_attention = _mea
+     print("[ATTN] xformers patch applied")
+ except Exception as e:
+     print(f"[ATTN] xformers patch failed: {e}")
+
+ logging.getLogger().setLevel(logging.INFO)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ DEFAULT_PROMPT = (
+     "A majestic eagle soaring over mountain peaks at sunset, "
+     "wings spread wide against the orange sky, feathers catching the light, "
+     "wind currents visible in the motion blur, cinematic slow motion, 4K quality"
+ )
+ DEFAULT_NEGATIVE_PROMPT = (
+     "worst quality, inconsistent motion, blurry, jittery, distorted, "
+     "deformed, artifacts, text, watermark, logo, frame, border, "
+     "low resolution, pixelated, unnatural, fake, CGI, cartoon"
+ )
+ DEFAULT_FRAME_RATE = 24.0
+ MIN_DIM, MAX_DIM, STEP = 256, 1280, 64
+ MIN_FRAMES, MAX_FRAMES = 9, 721
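+ # Note: both frame bounds already satisfy the 8*k + 1 frame layout that
+ # calculate_frames() below snaps to (9 = 8*1 + 1, 721 = 8*90 + 1), so
+ # clamping to them never breaks the frame-count constraint.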
+
+ # Resolution presets with high/low tiers
+ RESOLUTIONS = {
+     "high": {"16:9": (1536, 1024), "9:16": (1024, 1536), "1:1": (1024, 1024)},
+     "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768)},
+ }
+
+ LTX_MODEL_REPO = "Lightricks/LTX-2.3"
+ GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
+
+ # =============================================================================
+ # Custom HQ Pipeline with LoRA Cache Support
+ # =============================================================================
+
+ class HQPipelineWithCachedLoRA:
+     """
+     Custom HQ pipeline that:
+     1. Creates ONE ModelLedger WITHOUT LoRAs
+     2. Handles ALL LoRAs via cached state (distilled + 12 custom)
+     3. Supports CFG/negative prompts and guidance parameters
+     4. Reuses a single transformer for both stages
+     5. Uses 8 steps at half resolution + 3 steps at full resolution
+     """
+
+     def __init__(
+         self,
+         checkpoint_path: str,
+         spatial_upsampler_path: str,
+         gemma_root: str,
+         quantization: QuantizationPolicy | None = None,
+     ):
+         from ltx_pipelines.utils import ModelLedger
+         from ltx_pipelines.utils.types import PipelineComponents
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.dtype = torch.bfloat16
+
+         print("Creating ModelLedger (no LoRAs)...")
+         self.model_ledger = ModelLedger(
+             dtype=self.dtype,
+             device=self.device,
+             checkpoint_path=checkpoint_path,
+             gemma_root_path=gemma_root,
+             spatial_upsampler_path=spatial_upsampler_path,
+             loras=(),
+             quantization=quantization,
+         )
+
+         self.pipeline_components = PipelineComponents(
+             dtype=self.dtype,
+             device=self.device,
+         )
+
+         self._cached_state = None
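+         # Lifecycle note: prepare_lora_cache() (defined below) fills _cached_state
+         # with a fused transformer state dict, and apply_prepared_lora_state_to_pipeline()
+         # loads it into the preloaded transformer right before generation.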
+
+     def apply_cached_lora_state(self, state_dict):
+         """Apply pre-cached LoRA state to transformer."""
+         self._cached_state = state_dict
+
+     @torch.inference_mode()
+     def __call__(  # noqa: PLR0913
+         self,
+         prompt: str,
+         negative_prompt: str,
+         seed: int,
+         height: int,
+         width: int,
+         num_frames: int,
+         frame_rate: float,
+         video_guider_params: MultiModalGuiderParams,
+         audio_guider_params: MultiModalGuiderParams,
+         images: list,
+         tiling_config: TilingConfig | None = None,
+     ):
+         from ltx_pipelines.utils import (
+             assert_resolution,
+             cleanup_memory,
+             combined_image_conditionings,
+             encode_prompts,
+             res2s_audio_video_denoising_loop,
+             multi_modal_guider_denoising_func,
+             simple_denoising_func,
+             denoise_audio_video,
+         )
+         from ltx_core.tools import VideoLatentShape
+         from ltx_core.components.noisers import GaussianNoiser
+         from ltx_core.components.diffusion_steps import Res2sDiffusionStep
+         from ltx_core.types import VideoPixelShape
+         from ltx_core.model.upsampler import upsample_video
+         from ltx_core.model.video_vae import decode_video as vae_decode_video
+         from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
+
+         assert_resolution(height=height, width=width, is_two_stage=True)
+
+         device = self.device
+         dtype = self.dtype
+         generator = torch.Generator(device=device).manual_seed(seed)
+         noiser = GaussianNoiser(generator=generator)
+
+         # NO LoRA application here - done in apply_prepared_lora_state_to_pipeline()
+
+         ctx_p, ctx_n = encode_prompts(
+             [prompt, negative_prompt],
+             self.model_ledger,
+         )
+
+         v_context_p, a_context_p = ctx_p.video_encoding, ctx_p.audio_encoding
+         v_context_n, a_context_n = ctx_n.video_encoding, ctx_n.audio_encoding
+
+         # ===================== STAGE 1: 8 steps at half resolution =====================
+         stage_1_output_shape = VideoPixelShape(
+             batch=1, frames=num_frames,
+             width=width // 2, height=height // 2, fps=frame_rate
+         )
+
+         video_encoder = self.model_ledger.video_encoder()
+         stage_1_conditionings = combined_image_conditionings(
+             images=images,
+             height=stage_1_output_shape.height,
+             width=stage_1_output_shape.width,
+             video_encoder=video_encoder,
+             dtype=dtype,
+             device=device,
+         )
+         torch.cuda.synchronize()
+         del video_encoder
+         cleanup_memory()
+
+         transformer = self.model_ledger.transformer()
+
+         # Use DISTILLED_SIGMA_VALUES for 8 steps at half resolution
+         from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES
+         stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=device)
+         stepper = Res2sDiffusionStep()
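+         # Stage 1 runs the distilled sigma schedule (the 8 half-resolution steps
+         # from the class docstring) with the Res2s stepper; CFG/STG guidance is
+         # applied through the MultiModalGuider instances in the loop below.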
+
+         def first_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
+             return res2s_audio_video_denoising_loop(
+                 sigmas=sigmas,
+                 video_state=video_state,
+                 audio_state=audio_state,
+                 stepper=stepper,
+                 denoise_fn=multi_modal_guider_denoising_func(
+                     video_guider=MultiModalGuider(params=video_guider_params, negative_context=v_context_n),
+                     audio_guider=MultiModalGuider(params=audio_guider_params, negative_context=a_context_n),
+                     v_context=v_context_p,
+                     a_context=a_context_p,
+                     transformer=transformer,
+                 ),
+             )
+
+         video_state, audio_state = denoise_audio_video(
+             output_shape=stage_1_output_shape,
+             conditionings=stage_1_conditionings,
+             noiser=noiser,
+             sigmas=stage_1_sigmas,
+             stepper=stepper,
+             denoising_loop_fn=first_stage_denoising_loop,
+             components=self.pipeline_components,
+             dtype=dtype,
+             device=device,
+         )
+
+         torch.cuda.synchronize()
+         del transformer
+         cleanup_memory()
+
+         # ===================== UPSCALING =====================
+         video_encoder = self.model_ledger.video_encoder()
+         upscaled_video_latent = upsample_video(
+             latent=video_state.latent[:1],
+             video_encoder=video_encoder,
+             upsampler=self.model_ledger.spatial_upsampler(),
+         )
+
+         stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
+         stage_2_conditionings = combined_image_conditionings(
+             images=images,
+             height=stage_2_output_shape.height,
+             width=stage_2_output_shape.width,
+             video_encoder=video_encoder,
+             dtype=dtype,
+             device=device,
+         )
+         torch.cuda.synchronize()
+         del video_encoder
+         cleanup_memory()
+
+         # ===================== STAGE 2: 3 steps at full resolution =====================
+         transformer = self.model_ledger.transformer()
+
+         from ltx_pipelines.utils.constants import STAGE_2_DISTILLED_SIGMA_VALUES
+         stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=device)
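+         # Stage 2 refines at full resolution (the 3 steps from the class docstring),
+         # starting from the upscaled stage-1 latent rather than pure noise; see
+         # noise_scale / initial_video_latent below.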
+
+         def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
+             return res2s_audio_video_denoising_loop(
+                 sigmas=sigmas,
+                 video_state=video_state,
+                 audio_state=audio_state,
+                 stepper=stepper,
+                 denoise_fn=simple_denoising_func(
+                     video_context=v_context_p,
+                     audio_context=a_context_p,
+                     transformer=transformer,
+                 ),
+             )
+
+         video_state, audio_state = denoise_audio_video(
+             output_shape=stage_2_output_shape,
+             conditionings=stage_2_conditionings,
+             noiser=noiser,
+             sigmas=stage_2_sigmas,
+             stepper=stepper,
+             denoising_loop_fn=second_stage_denoising_loop,
+             components=self.pipeline_components,
+             dtype=dtype,
+             device=device,
+             noise_scale=stage_2_sigmas[0],
+             initial_video_latent=upscaled_video_latent,
+             initial_audio_latent=audio_state.latent,
+         )
+
+         torch.cuda.synchronize()
+         del transformer
+         cleanup_memory()
+
+         # ===================== DECODE =====================
+         decoded_video = vae_decode_video(
+             video_state.latent, self.model_ledger.video_decoder(), tiling_config, generator
+         )
+         decoded_audio = vae_decode_audio(
+             audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
+         )
+
+         return decoded_video, decoded_audio
+
+
+ # =============================================================================
+ # Model Download
+ # =============================================================================
+
+ print("=" * 80)
+ print("Downloading LTX-2.3 HQ models...")
+ print("=" * 80)
+
+ weights_dir = Path("weights")
+ weights_dir.mkdir(exist_ok=True)
+
+ checkpoint_path = hf_hub_download(
+     repo_id=LTX_MODEL_REPO,
+     filename="ltx-2.3-22b-dev.safetensors",
+     local_dir=str(weights_dir),
+     local_dir_use_symlinks=False,  # Ensure actual file copy, not symlink
+ )
+
+ # Force download if not present
+ if not os.path.exists(checkpoint_path):
+     print(f"Re-downloading checkpoint to {weights_dir}...")
+     checkpoint_path = hf_hub_download(
+         repo_id=LTX_MODEL_REPO,
+         filename="ltx-2.3-22b-dev.safetensors",  # re-fetch the same dev checkpoint
+         local_dir=str(weights_dir),
+         local_dir_use_symlinks=False,
+         force_download=True,
+     )
+
+ print(f"Checkpoint at: {checkpoint_path}")
+ print(f"File exists: {os.path.exists(checkpoint_path)}")
+ print(f"File size: {os.path.getsize(checkpoint_path) / 1024**3:.2f} GB")
+
+ spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
+ distilled_lora_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled-lora-384.safetensors")
+ gemma_root = snapshot_download(repo_id=GEMMA_REPO)
+
+ print(f"Dev checkpoint: {checkpoint_path}")
+ print(f"Spatial upsampler: {spatial_upsampler_path}")
+ print(f"Distilled LoRA: {distilled_lora_path}")
+ print(f"Gemma root: {gemma_root}")
+
+ # =============================================================================
+ # Download Custom LoRAs
+ # =============================================================================
+
+ LORA_REPO = "dagloop5/LoRA"
+
+ print("=" * 80)
+ print("Downloading custom LoRA adapters...")
+ print("=" * 80)
+
+ pose_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2_3_NSFW_furry_concat_v2.safetensors")
+ general_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_reasoning_I2V_V3.safetensors")
+ motion_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="motion_helper.safetensors")
+ dreamlay_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="DR34ML4Y_LTXXX_PREVIEW_RC1.safetensors")
+ mself_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="Furry Hyper Masturbation - LTX-2 I2V v1.safetensors")
+ dramatic_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2.3 - Orgasm.safetensors")
+ fluid_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="cr3ampi3_animation_i2v_ltx2_v1.0.safetensors")
+ liquid_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="liquid_wet_dr1pp_ltx2_v1.0_scaled.safetensors")
+ demopose_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="clapping-cheeks-audio-v001-alpha.safetensors")
+ voice_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="hentai_voice_ltx23.safetensors")
+ realism_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="FurryenhancerLTX2.3V1.215.safetensors")
+ transition_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2_takerpov_lora_v1.2.safetensors")
+
+ print("All 12 custom LoRAs downloaded, plus the distilled LoRA")
+ print("=" * 80)
+
+ # =============================================================================
+ # Pipeline Initialization
+ # =============================================================================
+
+ print("Initializing HQ Pipeline...")
+
+ pipeline = HQPipelineWithCachedLoRA(
+     checkpoint_path=checkpoint_path,
+     spatial_upsampler_path=spatial_upsampler_path,
+     gemma_root=gemma_root,
+     quantization=QuantizationPolicy.fp8_cast(),
+ )
+
+ print("Pipeline initialized!")
+ print("=" * 80)
+
+ # =============================================================================
+ # ZeroGPU Tensor Preloading - Single Transformer
+ # =============================================================================
+
+ print("Preloading models for ZeroGPU tensor packing...")
+
+ # Load shared components
+ _video_encoder = pipeline.model_ledger.video_encoder()
+ _video_decoder = pipeline.model_ledger.video_decoder()
+ _audio_encoder = pipeline.model_ledger.audio_encoder()
+ _audio_decoder = pipeline.model_ledger.audio_decoder()
+ _vocoder = pipeline.model_ledger.vocoder()
+ _spatial_upsampler = pipeline.model_ledger.spatial_upsampler()
+ _text_encoder = pipeline.model_ledger.text_encoder()
+ _embeddings_processor = pipeline.model_ledger.gemma_embeddings_processor()
+
+ # Load the SINGLE transformer
+ _transformer = pipeline.model_ledger.transformer()
+
+ # Replace ledger methods with lambdas returning cached instances
+ pipeline.model_ledger.video_encoder = lambda: _video_encoder
+ pipeline.model_ledger.video_decoder = lambda: _video_decoder
+ pipeline.model_ledger.audio_encoder = lambda: _audio_encoder
+ pipeline.model_ledger.audio_decoder = lambda: _audio_decoder
+ pipeline.model_ledger.vocoder = lambda: _vocoder
+ pipeline.model_ledger.spatial_upsampler = lambda: _spatial_upsampler
+ pipeline.model_ledger.text_encoder = lambda: _text_encoder
+ pipeline.model_ledger.gemma_embeddings_processor = lambda: _embeddings_processor
+ pipeline.model_ledger.transformer = lambda: _transformer
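+ # With the getters replaced by closures over the preloaded instances, every
+ # later ledger call (e.g. model_ledger.transformer() inside __call__) is a
+ # cheap lookup instead of a fresh load, and the del/cleanup_memory calls in
+ # the pipeline only drop local references, not the cached weights.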
+
+ print("All models preloaded for ZeroGPU tensor packing!")
+ print("=" * 80)
+ print("Pipeline ready!")
+ print("=" * 80)
+
+ # =============================================================================
+ # LoRA Cache Functions
+ # =============================================================================
+
+ LORA_CACHE_DIR = Path("lora_cache")
+ LORA_CACHE_DIR.mkdir(exist_ok=True)
+
+ def prepare_lora_cache(
+     distilled_strength: float,
+     pose_strength: float, general_strength: float, motion_strength: float,
+     dreamlay_strength: float, mself_strength: float, dramatic_strength: float,
+     fluid_strength: float, liquid_strength: float, demopose_strength: float,
+     voice_strength: float, realism_strength: float, transition_strength: float,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """Build cached LoRA state for single transformer."""
+     global pipeline
+
+     print("[LoRA] === Starting LoRA Cache Preparation ===")
+     progress(0.05, desc="Preparing LoRA cache...")
+
+     # Validate all LoRA files exist
+     print("[LoRA] Validating LoRA file paths...")
+     lora_files = [
+         ("Distilled", distilled_lora_path, distilled_strength),
+         ("Pose", pose_lora_path, pose_strength),
+         ("General", general_lora_path, general_strength),
+         ("Motion", motion_lora_path, motion_strength),
+         ("Dreamlay", dreamlay_lora_path, dreamlay_strength),
+         ("Mself", mself_lora_path, mself_strength),
+         ("Dramatic", dramatic_lora_path, dramatic_strength),
+         ("Fluid", fluid_lora_path, fluid_strength),
+         ("Liquid", liquid_lora_path, liquid_strength),
+         ("Demopose", demopose_lora_path, demopose_strength),
+         ("Voice", voice_lora_path, voice_strength),
+         ("Realism", realism_lora_path, realism_strength),
+         ("Transition", transition_lora_path, transition_strength),
+     ]
+
+     active_loras = []
+     for name, path, strength in lora_files:
+         if path is not None and float(strength) != 0.0:
+             active_loras.append((name, path, strength))
+             print(f"[LoRA] - {name}: strength={strength}")
+
+     print(f"[LoRA] Active LoRAs: {len(active_loras)}")
+
+     key_str = (
+         f"{checkpoint_path}:{distilled_strength}:{pose_strength}:{general_strength}:"
+         f"{motion_strength}:{dreamlay_strength}:{mself_strength}:{dramatic_strength}:"
+         f"{fluid_strength}:{liquid_strength}:{demopose_strength}:{voice_strength}:"
+         f"{realism_strength}:{transition_strength}"
+     )
+     key = hashlib.sha256(key_str.encode()).hexdigest()
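+     # The key hashes the checkpoint path plus every strength value, so any slider
+     # change yields a new cache file. Illustrative (hypothetical) example:
+     #   "weights/ltx-2.3-22b-dev.safetensors:0.9:0.0:...:0.0"
+     #   -> sha256 -> lora_cache/3f2a....safetensors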
+
+     cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
+     print(f"[LoRA] Cache key: {key[:16]}...")
+     print(f"[LoRA] Cache path: {cache_path}")
+
+     if cache_path.exists():
+         print("[LoRA] Loading from existing cache...")
+         progress(0.20, desc="Loading cached LoRA state...")
+         state = load_file(str(cache_path))
+         print(f"[LoRA] Loaded state dict with {len(state)} keys, size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
+         pipeline.apply_cached_lora_state(state)
+         print("[LoRA] State applied to pipeline._cached_state")
+         print("[LoRA] === LoRA Cache Preparation Complete ===")
+         return f"Loaded cached LoRA state: {cache_path.name} ({len(state)} keys)"
+
+     if not active_loras:
+         print("[LoRA] No non-zero LoRA strengths selected; nothing to prepare.")
+         print("[LoRA] === LoRA Cache Preparation Complete (no LoRAs) ===")
+         return "No non-zero LoRA strengths selected; nothing to prepare."
+
+     entries = [
+         (distilled_lora_path, distilled_strength),
+         (pose_lora_path, pose_strength),
+         (general_lora_path, general_strength),
+         (motion_lora_path, motion_strength),
+         (dreamlay_lora_path, dreamlay_strength),
+         (mself_lora_path, mself_strength),
+         (dramatic_lora_path, dramatic_strength),
+         (fluid_lora_path, fluid_strength),
+         (liquid_lora_path, liquid_strength),
+         (demopose_lora_path, demopose_strength),
+         (voice_lora_path, voice_strength),
+         (realism_lora_path, realism_strength),
+         (transition_lora_path, transition_strength),
+     ]
+
+     loras_for_builder = [
+         LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
+         for path, strength in entries
+         if path is not None and float(strength) != 0.0
+     ]
+
+     print(f"[LoRA] Building fused state on CPU with {len(loras_for_builder)} LoRAs...")
+     print("[LoRA] This may take several minutes (do not close the Space)...")
+     progress(0.35, desc="Building fused state (CPU)...")
+
+     import time
+     start_time = time.time()
+
+     tmp_ledger = pipeline.model_ledger.__class__(
+         dtype=torch.bfloat16,
+         device=torch.device("cpu"),
+         checkpoint_path=str(checkpoint_path),
+         spatial_upsampler_path=str(spatial_upsampler_path),
+         gemma_root_path=str(gemma_root),
+         loras=tuple(loras_for_builder),
+         quantization=None,
+     )
+     print(f"[LoRA] Temporary ledger created in {time.time() - start_time:.1f}s")
+
+     print("[LoRA] Loading transformer with LoRAs applied...")
+     transformer = tmp_ledger.transformer()
+     print(f"[LoRA] Transformer loaded in {time.time() - start_time:.1f}s")
+
+     print("[LoRA] Extracting state dict...")
+     progress(0.70, desc="Extracting fused state dict...")
+     state = {k: v.detach().cpu().contiguous() for k, v in transformer.state_dict().items()}
+     print(f"[LoRA] State dict extracted: {len(state)} keys")
+
+     print(f"[LoRA] Saving to cache: {cache_path}")
+     save_file(state, str(cache_path))
+     print(f"[LoRA] Cache saved, size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
+
+     print("[LoRA] Cleaning up temporary ledger...")
+     del transformer, tmp_ledger
+     gc.collect()
+     print(f"[LoRA] Cleanup complete in {time.time() - start_time:.1f}s total")
+
+     print("[LoRA] Applying state to pipeline._cached_state...")
+     progress(0.90, desc="Applying LoRA state to pipeline...")
+     pipeline.apply_cached_lora_state(state)
+
+     progress(1.0, desc="Done!")
+     print("[LoRA] === LoRA Cache Preparation Complete ===")
+     return f"Built and cached LoRA state: {cache_path.name} ({len(state)} keys, {time.time() - start_time:.1f}s)"
+
+ # =============================================================================
+ # LoRA State Application (called BEFORE pipeline generation)
+ # =============================================================================
+
+ def apply_prepared_lora_state_to_pipeline():
+     """
+     Apply the prepared LoRA state from pipeline._cached_state to the preloaded
+     transformer. This should be called BEFORE pipeline generation, not during.
+     """
+     print("[LoRA] === Applying LoRA State to Transformer ===")
+
+     if pipeline._cached_state is None:
+         print("[LoRA] No prepared LoRA state available; skipping.")
+         print("[LoRA] === LoRA Application Complete (no state) ===")
+         return False
+
+     try:
+         existing_transformer = _transformer  # The preloaded transformer from globals
+         state = pipeline._cached_state
+         print(f"[LoRA] Applying state dict with {len(state)} keys...")
+         print(f"[LoRA] State dict size: {sum(v.numel() * v.element_size() for v in state.values()) / 1024**3:.2f} GB")
+
+         import time
+         start_time = time.time()
+
+         with torch.no_grad():
+             missing, unexpected = existing_transformer.load_state_dict(state, strict=False)
+
+         print(f"[LoRA] load_state_dict completed in {time.time() - start_time:.1f}s")
+
+         if missing:
+             print(f"[LoRA] WARNING: {len(missing)} keys missing from state dict")
+         if unexpected:
+             print(f"[LoRA] WARNING: {len(unexpected)} unexpected keys in state dict")
+
+         if not missing and not unexpected:
+             print("[LoRA] State dict loaded successfully with no mismatches!")
+
+         print("[LoRA] === LoRA Application Complete (success) ===")
+         return True
+     except Exception as e:
+         print(f"[LoRA] FAILED to apply LoRA state: {type(e).__name__}: {e}")
+         print("[LoRA] === LoRA Application Complete (FAILED) ===")
+         return False
+
+ # =============================================================================
+ # Helper Functions
+ # =============================================================================
+
+ def log_memory(tag: str):
+     if torch.cuda.is_available():
+         allocated = torch.cuda.memory_allocated() / 1024**3
+         peak = torch.cuda.max_memory_allocated() / 1024**3
+         free, total = torch.cuda.mem_get_info()
+         print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / 1024**3:.2f}GB total={total / 1024**3:.2f}GB")
+
+
+ def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
+     ideal_frames = int(duration * frame_rate)
+     ideal_frames = max(ideal_frames, MIN_FRAMES)
+     k = round((ideal_frames - 1) / 8)
+     frames = k * 8 + 1
+     return min(frames, MAX_FRAMES)
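+ # Worked example at the default 24 fps: duration=10.0 s gives
+ # ideal_frames = 240, k = round(239 / 8) = round(29.875) = 30, and
+ # frames = 30 * 8 + 1 = 241, well under MAX_FRAMES = 721.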
+
+ def detect_aspect_ratio(image) -> str:
+     if image is None:
+         return "16:9"
+     if hasattr(image, "size"):
+         w, h = image.size
+     elif hasattr(image, "shape"):
+         h, w = image.shape[:2]
+     else:
+         return "16:9"
+     ratio = w / h
+     candidates = {"16:9": 16 / 9, "9:16": 9 / 16, "1:1": 1.0}
+     return min(candidates, key=lambda k: abs(ratio - candidates[k]))
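+ # Example: a 1920x1080 image has ratio ~1.778, closest to 16/9, so "16:9"
+ # is chosen; a 1080x1920 image maps to "9:16" and a square one to "1:1".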
+
+ def on_image_upload(first_image, last_image, high_res):
+     ref_image = first_image if first_image is not None else last_image
+     aspect = detect_aspect_ratio(ref_image)
+     tier = "high" if high_res else "low"
+     w, h = RESOLUTIONS[tier][aspect]
+     return gr.update(value=w), gr.update(value=h)
+
+
+ def on_highres_toggle(first_image, last_image, high_res):
+     ref_image = first_image if first_image is not None else last_image
+     aspect = detect_aspect_ratio(ref_image)
+     tier = "high" if high_res else "low"
+     w, h = RESOLUTIONS[tier][aspect]
+     return gr.update(value=w), gr.update(value=h)
+
+
+ def get_gpu_duration(
+     first_image,
+     last_image,
+     prompt: str,
+     negative_prompt: str,
+     duration: float,
+     gpu_duration: float,
+     seed: int = 42,
+     randomize_seed: bool = True,
+     height: int = 1024,
+     width: int = 1536,
+     video_cfg_scale: float = 1.0,
+     video_stg_scale: float = 0.0,
+     video_rescale_scale: float = 0.45,
+     video_a2v_scale: float = 3.0,
+     audio_cfg_scale: float = 1.0,
+     audio_stg_scale: float = 0.0,
+     audio_rescale_scale: float = 1.0,
+     audio_v2a_scale: float = 3.0,
+     distilled_strength: float = 0.0,
+     pose_strength: float = 0.0,
+     general_strength: float = 0.0,
+     motion_strength: float = 0.0,
+     dreamlay_strength: float = 0.0,
+     mself_strength: float = 0.0,
+     dramatic_strength: float = 0.0,
+     fluid_strength: float = 0.0,
+     liquid_strength: float = 0.0,
+     demopose_strength: float = 0.0,
+     voice_strength: float = 0.0,
+     realism_strength: float = 0.0,
+     transition_strength: float = 0.0,
+     progress=None,
+ ) -> int:
+     return int(gpu_duration)
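+ # get_gpu_duration mirrors generate_video's signature because a callable passed
+ # to @spaces.GPU(duration=...) is expected to be invoked with the same arguments
+ # as the wrapped function (an assumption about the ZeroGPU API); only the
+ # user-set gpu_duration slider value is actually used.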
+
+
+ @spaces.GPU(duration=get_gpu_duration)
+ @torch.inference_mode()
+ def generate_video(
+     first_image,
+     last_image,
+     prompt: str,
+     negative_prompt: str,
+     duration: float,
+     gpu_duration: float,
+     seed: int = 42,
+     randomize_seed: bool = True,
+     height: int = 1024,
+     width: int = 1536,
+     video_cfg_scale: float = 1.0,
+     video_stg_scale: float = 0.0,
+     video_rescale_scale: float = 0.45,
+     video_a2v_scale: float = 3.0,
+     audio_cfg_scale: float = 1.0,
+     audio_stg_scale: float = 0.0,
+     audio_rescale_scale: float = 1.0,
+     audio_v2a_scale: float = 3.0,
+     distilled_strength: float = 0.0,
+     pose_strength: float = 0.0,
+     general_strength: float = 0.0,
+     motion_strength: float = 0.0,
+     dreamlay_strength: float = 0.0,
+     mself_strength: float = 0.0,
+     dramatic_strength: float = 0.0,
+     fluid_strength: float = 0.0,
+     liquid_strength: float = 0.0,
+     demopose_strength: float = 0.0,
+     voice_strength: float = 0.0,
+     realism_strength: float = 0.0,
+     transition_strength: float = 0.0,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     # Resolve the seed before the try block so the except path can return it
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+     try:
+         torch.cuda.reset_peak_memory_stats()
+         log_memory("start")
+
+         print(f"Using seed: {current_seed}")
+
+         print(f"Resolution: {width}x{height}")
+
+         num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
+         print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")
+
+         images = []
+         output_dir = Path("outputs")
+         output_dir.mkdir(exist_ok=True)
+
+         if first_image is not None:
+             temp_first_path = output_dir / f"temp_first_{current_seed}.jpg"
+             if hasattr(first_image, "save"):
+                 first_image.save(temp_first_path)
+             else:
+                 import shutil
+                 shutil.copy(first_image, temp_first_path)
+             images.append(ImageConditioningInput(path=str(temp_first_path), frame_idx=0, strength=1.0))
+
+         if last_image is not None:
+             temp_last_path = output_dir / f"temp_last_{current_seed}.jpg"
+             if hasattr(last_image, "save"):
+                 last_image.save(temp_last_path)
+             else:
+                 import shutil
+                 shutil.copy(last_image, temp_last_path)
+             images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
+
+         tiling_config = TilingConfig.default()
+         video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
+
+         video_guider_params = MultiModalGuiderParams(
+             cfg_scale=video_cfg_scale,
+             stg_scale=video_stg_scale,
+             rescale_scale=video_rescale_scale,
+             modality_scale=video_a2v_scale,
+             skip_step=0,
+             stg_blocks=[],
+         )
+
+         audio_guider_params = MultiModalGuiderParams(
+             cfg_scale=audio_cfg_scale,
+             stg_scale=audio_stg_scale,
+             rescale_scale=audio_rescale_scale,
+             modality_scale=audio_v2a_scale,
+             skip_step=0,
+             stg_blocks=[],
+         )
+
+         log_memory("before pipeline call")
+
+         apply_prepared_lora_state_to_pipeline()
+
+         video, audio = pipeline(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             seed=current_seed,
+             height=height,
+             width=width,
+             num_frames=num_frames,
+             frame_rate=DEFAULT_FRAME_RATE,
+             video_guider_params=video_guider_params,
+             audio_guider_params=audio_guider_params,
+             images=images,
+             tiling_config=tiling_config,
+         )
+
+         log_memory("after pipeline call")
+
+         # NamedTemporaryFile instead of the deprecated, insecure tempfile.mktemp
+         output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
+         encode_video(
+             video=video,
+             fps=DEFAULT_FRAME_RATE,
+             audio=audio,
+             output_path=output_path,
+             video_chunks_number=video_chunks_number,
+         )
+
+         log_memory("after encode_video")
+         return str(output_path), current_seed
+
+     except Exception as e:
+         import traceback
+         log_memory("on error")
+         print(f"Error: {e}\n{traceback.format_exc()}")
+         return None, current_seed
+
+
+ # =============================================================================
+ # Gradio UI
+ # =============================================================================
+
+ css = """
+ .fillable {max-width: 1200px !important}
+ .progress-text {color: black}
+ """
+
+ # theme/css belong on the Blocks constructor, not on launch()
+ with gr.Blocks(title="LTX-2.3 Two-Stage HQ with LoRA Cache", theme=gr.themes.Citrus(), css=css) as demo:
+     gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation with LoRA Cache")
+     gr.Markdown(
+         "High-quality text/image-to-video with cached LoRA state + CFG guidance. "
+         "[[Model]](https://huggingface.co/Lightricks/LTX-2.3)"
+     )
+
+     with gr.Row():
+         # LEFT SIDE: Input Controls
+         with gr.Column():
+             with gr.Row():
+                 first_image = gr.Image(label="First Frame (Optional)", type="pil")
+                 last_image = gr.Image(label="Last Frame (Optional)", type="pil")
+
+             prompt = gr.Textbox(
+                 label="Prompt",
+                 value=DEFAULT_PROMPT,
+                 lines=3,
+             )
+
+             negative_prompt = gr.Textbox(
+                 label="Negative Prompt",
+                 value=DEFAULT_NEGATIVE_PROMPT,
+                 lines=2,
+             )
+
+             duration = gr.Slider(
+                 label="Duration (seconds)",
+                 minimum=1.0, maximum=30.0, value=10.0, step=0.1,
+             )
+
+             with gr.Row():
+                 seed = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=MAX_SEED)
+                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
+             with gr.Row():
+                 high_res = gr.Checkbox(label="High Resolution", value=True)
+
+             with gr.Row():
+                 width = gr.Number(label="Width", value=1536, precision=0)
+                 height = gr.Number(label="Height", value=1024, precision=0)
+
+             generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 gr.Markdown("### Video Guidance Parameters")
+
+                 with gr.Row():
+                     video_cfg_scale = gr.Slider(
+                         label="Video CFG Scale", minimum=1.0, maximum=10.0,
+                         value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale, step=0.1
+                     )
+                     video_stg_scale = gr.Slider(
+                         label="Video STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1
+                     )
+
+                 with gr.Row():
+                     video_rescale_scale = gr.Slider(
+                         label="Video Rescale", minimum=0.0, maximum=2.0, value=0.45, step=0.1
+                     )
+                     video_a2v_scale = gr.Slider(
+                         label="A2V Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1
+                     )
+
+                 gr.Markdown("### Audio Guidance Parameters")
+
+                 with gr.Row():
+                     audio_cfg_scale = gr.Slider(
+                         label="Audio CFG Scale", minimum=1.0, maximum=15.0,
+                         value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale, step=0.1
+                     )
+                     audio_stg_scale = gr.Slider(
+                         label="Audio STG Scale", minimum=0.0, maximum=2.0, value=0.0, step=0.1
+                     )
+
+                 with gr.Row():
+                     audio_rescale_scale = gr.Slider(
+                         label="Audio Rescale", minimum=0.0, maximum=2.0, value=1.0, step=0.1
+                     )
+                     audio_v2a_scale = gr.Slider(
+                         label="V2A Scale", minimum=0.0, maximum=5.0, value=3.0, step=0.1
+                     )
+
+         # RIGHT SIDE: Output and LoRA
+         with gr.Column():
+             output_video = gr.Video(label="Generated Video", autoplay=False)
+
+             gpu_duration = gr.Slider(
+                 label="ZeroGPU duration (seconds)",
+                 minimum=30.0, maximum=240.0, value=90.0, step=1.0,
+                 info="Increase for longer videos, higher resolution, or LoRA usage"
+             )
+
+             gr.Markdown("### LoRA Adapter Strengths")
+             gr.Markdown("Set to 0 to disable, then click 'Prepare / Load LoRA Cache'")
+
+             with gr.Row():
+                 distilled_strength = gr.Slider(label="Distilled LoRA", minimum=0.0, maximum=1.5, value=0.0, step=0.01)
+                 pose_strength = gr.Slider(label="Anthro Enhancer", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+
+             with gr.Row():
+                 general_strength = gr.Slider(label="Reasoning Enhancer", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+                 motion_strength = gr.Slider(label="Anthro Posing", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+
+             with gr.Row():
+                 dreamlay_strength = gr.Slider(label="Dreamlay", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+                 mself_strength = gr.Slider(label="Mself", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+
+             with gr.Row():
+                 dramatic_strength = gr.Slider(label="Dramatic", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+                 fluid_strength = gr.Slider(label="Fluid Helper", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+
+             with gr.Row():
+                 liquid_strength = gr.Slider(label="Liquid Helper", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+                 demopose_strength = gr.Slider(label="Audio Helper", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+
+             with gr.Row():
+                 voice_strength = gr.Slider(label="Voice Helper", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+                 realism_strength = gr.Slider(label="Anthro Realism", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+
+             with gr.Row():
+                 transition_strength = gr.Slider(label="POV", minimum=0.0, maximum=2.0, value=0.0, step=0.01)
+                 gr.Markdown("")  # Spacer for alignment
+
+             prepare_lora_btn = gr.Button("Prepare / Load LoRA Cache", variant="secondary")
+             lora_status = gr.Textbox(
+                 label="LoRA Cache Status",
+                 value="No LoRA state prepared yet.",
+                 interactive=False,
+             )
+
+     # Event handlers
+     first_image.change(fn=on_image_upload, inputs=[first_image, last_image, high_res], outputs=[width, height])
+     last_image.change(fn=on_image_upload, inputs=[first_image, last_image, high_res], outputs=[width, height])
+     high_res.change(fn=on_highres_toggle, inputs=[first_image, last_image, high_res], outputs=[width, height])
+
+     prepare_lora_btn.click(
+         fn=prepare_lora_cache,
+         inputs=[distilled_strength, pose_strength, general_strength, motion_strength, dreamlay_strength,
+                 mself_strength, dramatic_strength, fluid_strength, liquid_strength,
+                 demopose_strength, voice_strength, realism_strength, transition_strength],
+         outputs=[lora_status],
+     )
+
+     generate_btn.click(
+         fn=generate_video,
+         inputs=[
+             first_image, last_image, prompt, negative_prompt, duration, gpu_duration,
+             seed, randomize_seed, height, width,
+             video_cfg_scale, video_stg_scale, video_rescale_scale, video_a2v_scale,
+             audio_cfg_scale, audio_stg_scale, audio_rescale_scale, audio_v2a_scale,
+             distilled_strength, pose_strength, general_strength, motion_strength,
+             dreamlay_strength, mself_strength, dramatic_strength, fluid_strength,
+             liquid_strength, demopose_strength, voice_strength, realism_strength,
+             transition_strength,
+         ],
+         outputs=[output_video, seed],
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue().launch(mcp_server=False)