Update app.py
Browse files
app.py
CHANGED
|
@@ -51,25 +51,29 @@ from safetensors import safe_open
|
|
| 51 |
import json
|
| 52 |
import requests
|
| 53 |
|
| 54 |
-
from ltx_core.components.diffusion_steps import
|
|
|
|
| 55 |
from ltx_core.components.noisers import GaussianNoiser
|
| 56 |
-
from ltx_core.model.
|
| 57 |
-
from ltx_core.
|
| 58 |
-
from
|
| 59 |
-
from
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
from ltx_pipelines.utils.helpers import (
|
| 66 |
-
|
| 67 |
combined_image_conditionings,
|
| 68 |
-
|
| 69 |
-
encode_prompts,
|
| 70 |
-
simple_denoising_func,
|
| 71 |
)
|
| 72 |
-
from ltx_pipelines.utils.media_io import
|
|
|
|
| 73 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 74 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 75 |
|
|
@@ -101,167 +105,169 @@ RESOLUTIONS = {
|
|
| 101 |
}
|
| 102 |
|
| 103 |
|
| 104 |
-
class
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
def __call__(
|
| 108 |
self,
|
| 109 |
prompt: str,
|
|
|
|
| 110 |
seed: int,
|
| 111 |
height: int,
|
| 112 |
width: int,
|
| 113 |
num_frames: int,
|
| 114 |
frame_rate: float,
|
| 115 |
images: list[ImageConditioningInput],
|
| 116 |
-
audio_path: str | None = None,
|
| 117 |
tiling_config: TilingConfig | None = None,
|
| 118 |
enhance_prompt: bool = False,
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
width=width,
|
| 128 |
-
num_frames=num_frames,
|
| 129 |
-
frame_rate=frame_rate,
|
| 130 |
-
images=images,
|
| 131 |
-
tiling_config=tiling_config,
|
| 132 |
-
enhance_prompt=enhance_prompt,
|
| 133 |
-
)
|
| 134 |
|
| 135 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 136 |
noiser = GaussianNoiser(generator=generator)
|
| 137 |
-
stepper = EulerDiffusionStep()
|
| 138 |
dtype = torch.bfloat16
|
| 139 |
|
| 140 |
-
|
| 141 |
-
[prompt],
|
| 142 |
-
self.model_ledger,
|
| 143 |
enhance_first_prompt=enhance_prompt,
|
| 144 |
-
enhance_prompt_image=images[0]
|
|
|
|
|
|
|
| 145 |
)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
video_duration = num_frames / frame_rate
|
| 149 |
-
decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
|
| 150 |
-
if decoded_audio is None:
|
| 151 |
-
raise ValueError(f"Could not extract audio stream from {audio_path}")
|
| 152 |
-
|
| 153 |
-
encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
|
| 154 |
-
audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
|
| 155 |
-
expected_frames = audio_shape.frames
|
| 156 |
-
actual_frames = encoded_audio_latent.shape[2]
|
| 157 |
-
|
| 158 |
-
if actual_frames > expected_frames:
|
| 159 |
-
encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
|
| 160 |
-
elif actual_frames < expected_frames:
|
| 161 |
-
pad = torch.zeros(
|
| 162 |
-
encoded_audio_latent.shape[0],
|
| 163 |
-
encoded_audio_latent.shape[1],
|
| 164 |
-
expected_frames - actual_frames,
|
| 165 |
-
encoded_audio_latent.shape[3],
|
| 166 |
-
device=encoded_audio_latent.device,
|
| 167 |
-
dtype=encoded_audio_latent.dtype,
|
| 168 |
-
)
|
| 169 |
-
encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
|
| 170 |
-
|
| 171 |
-
video_encoder = self.model_ledger.video_encoder()
|
| 172 |
-
transformer = self.model_ledger.transformer()
|
| 173 |
-
stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
|
| 174 |
-
|
| 175 |
-
def denoising_loop(sigmas, video_state, audio_state, stepper):
|
| 176 |
-
return euler_denoising_loop(
|
| 177 |
-
sigmas=sigmas,
|
| 178 |
-
video_state=video_state,
|
| 179 |
-
audio_state=audio_state,
|
| 180 |
-
stepper=stepper,
|
| 181 |
-
denoise_fn=simple_denoising_func(
|
| 182 |
-
video_context=video_context,
|
| 183 |
-
audio_context=audio_context,
|
| 184 |
-
transformer=transformer,
|
| 185 |
-
),
|
| 186 |
-
)
|
| 187 |
|
| 188 |
stage_1_output_shape = VideoPixelShape(
|
| 189 |
-
batch=1,
|
| 190 |
-
frames=num_frames,
|
| 191 |
-
width=width // 2,
|
| 192 |
-
height=height // 2,
|
| 193 |
-
fps=frame_rate,
|
| 194 |
)
|
| 195 |
-
stage_1_conditionings =
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
)
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
noiser=noiser,
|
| 207 |
-
sigmas=stage_1_sigmas,
|
| 208 |
stepper=stepper,
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
)
|
| 215 |
|
| 216 |
-
|
| 217 |
-
cleanup_memory()
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
height=stage_2_output_shape.height,
|
| 229 |
-
width=stage_2_output_shape.width,
|
| 230 |
-
video_encoder=video_encoder,
|
| 231 |
-
dtype=dtype,
|
| 232 |
-
device=self.device,
|
| 233 |
)
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
| 237 |
noiser=noiser,
|
| 238 |
-
sigmas=stage_2_sigmas,
|
| 239 |
stepper=stepper,
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
)
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
cleanup_memory()
|
| 253 |
-
|
| 254 |
-
decoded_video = vae_decode_video(
|
| 255 |
-
video_state.latent,
|
| 256 |
-
self.model_ledger.video_decoder(),
|
| 257 |
-
tiling_config,
|
| 258 |
-
generator,
|
| 259 |
-
)
|
| 260 |
-
original_audio = Audio(
|
| 261 |
-
waveform=decoded_audio.waveform.squeeze(0),
|
| 262 |
-
sampling_rate=decoded_audio.sampling_rate,
|
| 263 |
-
)
|
| 264 |
-
return decoded_video, original_audio
|
| 265 |
|
| 266 |
|
| 267 |
# Model repos
|
|
@@ -276,11 +282,11 @@ print("=" * 80)
|
|
| 276 |
# LoRA cache directory and currently-applied key
|
| 277 |
LORA_CACHE_DIR = Path("lora_cache")
|
| 278 |
LORA_CACHE_DIR.mkdir(exist_ok=True)
|
| 279 |
-
current_lora_key: str | None = None
|
| 280 |
|
|
|
|
| 281 |
PENDING_LORA_KEY: str | None = None
|
| 282 |
-
|
| 283 |
-
PENDING_LORA_STATUS: str = "No LoRA
|
| 284 |
|
| 285 |
weights_dir = Path("weights")
|
| 286 |
weights_dir.mkdir(exist_ok=True)
|
|
@@ -376,29 +382,19 @@ def prepare_lora_cache(
|
|
| 376 |
progress=gr.Progress(track_tqdm=True),
|
| 377 |
):
|
| 378 |
"""
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
- loads cached fused transformer state_dict, or
|
| 382 |
-
- builds fused transformer on CPU and saves it
|
| 383 |
-
The resulting state_dict is stored in memory and can be applied later.
|
| 384 |
"""
|
| 385 |
-
global PENDING_LORA_KEY,
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
state = load_file(str(cache_path))
|
| 396 |
-
PENDING_LORA_KEY = key
|
| 397 |
-
PENDING_LORA_STATE = state
|
| 398 |
-
PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
|
| 399 |
-
return PENDING_LORA_STATUS
|
| 400 |
-
except Exception as e:
|
| 401 |
-
print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
|
| 402 |
|
| 403 |
entries = [
|
| 404 |
(pose_lora_path, round(float(pose_strength), 2)),
|
|
@@ -414,6 +410,7 @@ def prepare_lora_cache(
|
|
| 414 |
(realism_lora_path, round(float(realism_strength), 2)),
|
| 415 |
(transition_lora_path, round(float(transition_strength), 2)),
|
| 416 |
]
|
|
|
|
| 417 |
loras_for_builder = [
|
| 418 |
LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
|
| 419 |
for path, strength in entries
|
|
@@ -422,35 +419,31 @@ def prepare_lora_cache(
|
|
| 422 |
|
| 423 |
if not loras_for_builder:
|
| 424 |
PENDING_LORA_KEY = None
|
| 425 |
-
|
| 426 |
PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
|
| 427 |
return PENDING_LORA_STATUS
|
| 428 |
|
| 429 |
-
tmp_ledger = None
|
| 430 |
-
new_transformer_cpu = None
|
| 431 |
try:
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
}
|
| 449 |
-
save_file(state, str(cache_path))
|
| 450 |
|
| 451 |
PENDING_LORA_KEY = key
|
| 452 |
-
|
| 453 |
-
PENDING_LORA_STATUS = f"
|
| 454 |
return PENDING_LORA_STATUS
|
| 455 |
|
| 456 |
except Exception as e:
|
|
@@ -458,45 +451,32 @@ def prepare_lora_cache(
|
|
| 458 |
print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
|
| 459 |
print(traceback.format_exc())
|
| 460 |
PENDING_LORA_KEY = None
|
| 461 |
-
|
| 462 |
PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
|
| 463 |
return PENDING_LORA_STATUS
|
| 464 |
|
| 465 |
-
finally:
|
| 466 |
-
try:
|
| 467 |
-
del new_transformer_cpu
|
| 468 |
-
except Exception:
|
| 469 |
-
pass
|
| 470 |
-
try:
|
| 471 |
-
del tmp_ledger
|
| 472 |
-
except Exception:
|
| 473 |
-
pass
|
| 474 |
-
gc.collect()
|
| 475 |
|
|
|
|
|
|
|
| 476 |
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
Fast step: copy the already prepared CPU state into the live transformer.
|
| 480 |
-
This is the only part that should remain near generation time.
|
| 481 |
-
"""
|
| 482 |
-
global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE
|
| 483 |
-
|
| 484 |
-
if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
|
| 485 |
-
print("[LoRA] No prepared LoRA state available; skipping.")
|
| 486 |
return False
|
| 487 |
|
| 488 |
if current_lora_key == PENDING_LORA_KEY:
|
| 489 |
-
print("[LoRA] Prepared LoRA
|
| 490 |
return True
|
| 491 |
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
|
|
|
|
|
|
| 497 |
|
| 498 |
current_lora_key = PENDING_LORA_KEY
|
| 499 |
-
print("[LoRA] Prepared LoRA
|
| 500 |
return True
|
| 501 |
|
| 502 |
# ---- REPLACE PRELOAD BLOCK START ----
|
|
@@ -588,8 +568,8 @@ def on_highres_toggle(first_image, last_image, high_res):
|
|
| 588 |
def get_gpu_duration(
|
| 589 |
first_image,
|
| 590 |
last_image,
|
| 591 |
-
input_audio,
|
| 592 |
prompt: str,
|
|
|
|
| 593 |
duration: float,
|
| 594 |
gpu_duration: float,
|
| 595 |
enhance_prompt: bool = True,
|
|
@@ -618,8 +598,8 @@ def get_gpu_duration(
|
|
| 618 |
def generate_video(
|
| 619 |
first_image,
|
| 620 |
last_image,
|
| 621 |
-
input_audio,
|
| 622 |
prompt: str,
|
|
|
|
| 623 |
duration: float,
|
| 624 |
gpu_duration: float,
|
| 625 |
enhance_prompt: bool = True,
|
|
@@ -682,15 +662,18 @@ def generate_video(
|
|
| 682 |
|
| 683 |
video, audio = pipeline(
|
| 684 |
prompt=prompt,
|
|
|
|
| 685 |
seed=current_seed,
|
| 686 |
height=int(height),
|
| 687 |
width=int(width),
|
| 688 |
num_frames=num_frames,
|
| 689 |
frame_rate=frame_rate,
|
| 690 |
images=images,
|
| 691 |
-
audio_path=input_audio,
|
| 692 |
tiling_config=tiling_config,
|
| 693 |
enhance_prompt=enhance_prompt,
|
|
|
|
|
|
|
|
|
|
| 694 |
)
|
| 695 |
|
| 696 |
log_memory("after pipeline call")
|
|
@@ -723,7 +706,6 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
|
|
| 723 |
with gr.Row():
|
| 724 |
first_image = gr.Image(label="First Frame (Optional)", type="pil")
|
| 725 |
last_image = gr.Image(label="Last Frame (Optional)", type="pil")
|
| 726 |
-
input_audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
|
| 727 |
prompt = gr.Textbox(
|
| 728 |
label="Prompt",
|
| 729 |
info="for best results - make it as elaborate as possible",
|
|
@@ -731,6 +713,12 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
|
|
| 731 |
lines=3,
|
| 732 |
placeholder="Describe the motion and animation you want...",
|
| 733 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=30.0, value=10.0, step=0.1)
|
| 735 |
|
| 736 |
|
|
@@ -817,13 +805,13 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
|
|
| 817 |
[
|
| 818 |
None,
|
| 819 |
"pinkknit.jpg",
|
| 820 |
-
None,
|
| 821 |
"The camera falls downward through darkness as if dropped into a tunnel. "
|
| 822 |
"As it slows, five friends wearing pink knitted hats and sunglasses lean "
|
| 823 |
"over and look down toward the camera with curious expressions. The lens "
|
| 824 |
"has a strong fisheye effect, creating a circular frame around them. They "
|
| 825 |
"crowd together closely, forming a symmetrical cluster while staring "
|
| 826 |
"directly into the lens.",
|
|
|
|
| 827 |
3.0,
|
| 828 |
80.0,
|
| 829 |
False,
|
|
@@ -846,7 +834,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
|
|
| 846 |
],
|
| 847 |
],
|
| 848 |
inputs=[
|
| 849 |
-
first_image, last_image,
|
| 850 |
enhance_prompt, seed, randomize_seed, height, width,
|
| 851 |
pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
|
| 852 |
],
|
|
@@ -879,7 +867,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
|
|
| 879 |
generate_btn.click(
|
| 880 |
fn=generate_video,
|
| 881 |
inputs=[
|
| 882 |
-
first_image, last_image,
|
| 883 |
seed, randomize_seed, height, width,
|
| 884 |
pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
|
| 885 |
],
|
|
|
|
| 51 |
import json
|
| 52 |
import requests
|
| 53 |
|
| 54 |
+
from ltx_core.components.diffusion_steps import Res2sDiffusionStep
|
| 55 |
+
from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
|
| 56 |
from ltx_core.components.noisers import GaussianNoiser
|
| 57 |
+
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
|
| 58 |
+
from ltx_core.types import Audio, VideoLatentShape, VideoPixelShape
|
| 59 |
+
from ltx_pipelines.utils.args import ImageConditioningInput, hq_2_stage_arg_parser
|
| 60 |
+
from ltx_pipelines.utils.blocks import (
|
| 61 |
+
AudioDecoder,
|
| 62 |
+
DiffusionStage,
|
| 63 |
+
ImageConditioner,
|
| 64 |
+
PromptEncoder,
|
| 65 |
+
VideoDecoder,
|
| 66 |
+
VideoUpsampler,
|
| 67 |
+
)
|
| 68 |
+
from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS, STAGE_2_DISTILLED_SIGMAS
|
| 69 |
+
from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser
|
| 70 |
from ltx_pipelines.utils.helpers import (
|
| 71 |
+
assert_resolution,
|
| 72 |
combined_image_conditionings,
|
| 73 |
+
get_device,
|
|
|
|
|
|
|
| 74 |
)
|
| 75 |
+
from ltx_pipelines.utils.media_io import encode_video
|
| 76 |
+
from ltx_pipelines.utils.samplers import res2s_audio_video_denoising_loop
|
| 77 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 78 |
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 79 |
|
|
|
|
| 105 |
}
|
| 106 |
|
| 107 |
|
| 108 |
+
class LTX23NegativePromptTwoStagePipeline:
    """Two-stage video/audio generation pipeline with negative-prompt guidance.

    Stage 1 denoises at half the requested width/height with a
    ``GuidedDenoiser`` that receives both the positive and the negative prompt
    encodings.  The stage-1 video latent is then run through the upsampler and
    refined at the requested resolution by stage 2 with a ``SimpleDenoiser``
    (positive context only).  Finally the video and audio latents are decoded.

    NOTE(review): ``ModalitySpec``, ``LTX2Scheduler``, ``QuantizationPolicy``,
    ``Registry`` and ``Iterator`` are used here but are not among the imports
    visible next to this class -- confirm they are imported elsewhere in the
    file.
    """

    def __init__(
        self,
        checkpoint_path: str,
        spatial_upsampler_path: str,
        gemma_root: str,
        loras: tuple[LoraPathStrengthAndSDOps, ...],
        device: torch.device | None = None,
        quantization: QuantizationPolicy | None = None,
        registry: Registry | None = None,
        torch_compile: bool = False,
    ):
        """Build every pipeline component from the given checkpoints.

        Args:
            checkpoint_path: Checkpoint handed to every component.
            spatial_upsampler_path: Extra checkpoint handed to the upsampler.
            gemma_root: Passed to the prompt encoder only.
            loras: LoRA descriptors forwarded to both diffusion stages.
            device: Target device; falls back to ``get_device()`` when falsy.
            quantization: Forwarded to both diffusion stages.
            registry: Forwarded to every component.
            torch_compile: Forwarded to both diffusion stages.
        """
        self.device = device or get_device()
        self.dtype = torch.bfloat16
        self._scheduler = LTX2Scheduler()

        self.prompt_encoder = PromptEncoder(
            checkpoint_path, gemma_root, self.dtype, self.device, registry=registry
        )
        self.image_conditioner = ImageConditioner(
            checkpoint_path, self.dtype, self.device, registry=registry
        )
        self.upsampler = VideoUpsampler(
            checkpoint_path, spatial_upsampler_path, self.dtype, self.device, registry=registry
        )
        self.video_decoder = VideoDecoder(checkpoint_path, self.dtype, self.device, registry=registry)
        self.audio_decoder = AudioDecoder(checkpoint_path, self.dtype, self.device, registry=registry)

        # Both stages are configured identically; they are still two separate
        # DiffusionStage objects, exactly as before.
        stage_kwargs = {
            "loras": tuple(loras),
            "quantization": quantization,
            "registry": registry,
            "torch_compile": torch_compile,
        }
        self.stage_1 = DiffusionStage(checkpoint_path, self.dtype, self.device, **stage_kwargs)
        self.stage_2 = DiffusionStage(checkpoint_path, self.dtype, self.device, **stage_kwargs)

    def __call__(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        height: int,
        width: int,
        num_frames: int,
        frame_rate: float,
        images: list[ImageConditioningInput],
        tiling_config: TilingConfig | None = None,
        enhance_prompt: bool = False,
        streaming_prefetch_count: int | None = None,
        max_batch_size: int = 1,
        stage_1_sigmas: torch.Tensor | None = None,
        stage_2_sigmas: torch.Tensor = STAGE_2_DISTILLED_SIGMAS,
        video_guider_params: MultiModalGuiderParams | None = None,
        audio_guider_params: MultiModalGuiderParams | None = None,
        num_inference_steps: int | None = None,
    ) -> tuple[Iterator[torch.Tensor], Audio]:
        """Generate video and audio for ``prompt`` while steering away from ``negative_prompt``.

        Args:
            prompt: Positive prompt (encoded first, optionally enhanced).
            negative_prompt: Prompt whose encoding is used as the guiders'
                negative context in stage 1.
            seed: Seeds the noise generator and the prompt enhancer.
            height: Final output height; stage 1 runs at ``height // 2``.
            width: Final output width; stage 1 runs at ``width // 2``.
            num_frames: Number of frames for both stages.
            frame_rate: Frames per second for both stages.
            images: Image conditioning inputs; the first element of the first
                entry is also offered to the prompt enhancer.
            tiling_config: Forwarded to the video decoder.
            enhance_prompt: Whether the first (positive) prompt is enhanced.
            streaming_prefetch_count: Forwarded to the prompt encoder and both
                diffusion stages.
            max_batch_size: Forwarded to stage 1 only.
            stage_1_sigmas: Explicit stage-1 sigma schedule.  When ``None`` the
                schedule is produced by the scheduler from
                ``num_inference_steps``.
            stage_2_sigmas: Stage-2 sigma schedule.
            video_guider_params: Parameters for the stage-1 video guider.
            audio_guider_params: Parameters for the stage-1 audio guider.
            num_inference_steps: Step count for the scheduler; required only
                when ``stage_1_sigmas`` is ``None``.

        Returns:
            ``(decoded_video, decoded_audio)`` as produced by the video and
            audio decoders.

        Raises:
            ValueError: If neither ``stage_1_sigmas`` nor
                ``num_inference_steps`` is provided.
        """
        assert_resolution(height=height, width=width, is_two_stage=True)

        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)
        dtype = self.dtype

        ctx_p, ctx_n = self.prompt_encoder(
            [prompt, negative_prompt],
            enhance_first_prompt=enhance_prompt,
            enhance_prompt_image=images[0][0] if images else None,
            enhance_prompt_seed=seed,
            streaming_prefetch_count=streaming_prefetch_count,
        )
        v_context_p, a_context_p = ctx_p.video_encoding, ctx_p.audio_encoding
        v_context_n, a_context_n = ctx_n.video_encoding, ctx_n.audio_encoding

        # Stage 1 works at half the requested spatial resolution.
        stage_1_output_shape = VideoPixelShape(
            batch=1, frames=num_frames, width=width // 2, height=height // 2, fps=frame_rate
        )
        stage_1_conditionings = self.image_conditioner(
            lambda enc: combined_image_conditionings(
                images=images,
                height=stage_1_output_shape.height,
                width=stage_1_output_shape.width,
                video_encoder=enc,
                dtype=dtype,
                device=self.device,
            )
        )

        stepper = Res2sDiffusionStep()
        if stage_1_sigmas is None:
            # Previously this branch referenced an undefined name and failed
            # with NameError; fail with an actionable message instead.
            if num_inference_steps is None:
                raise ValueError(
                    "Either stage_1_sigmas or num_inference_steps must be provided."
                )
            empty_latent = torch.empty(
                VideoLatentShape.from_pixel_shape(stage_1_output_shape).to_torch_shape()
            )
            stage_1_sigmas = self._scheduler.execute(latent=empty_latent, steps=num_inference_steps)
        sigmas = stage_1_sigmas.to(dtype=torch.float32, device=self.device)

        # NOTE(review): the guider params default to None and are passed on
        # unchanged -- confirm MultiModalGuider accepts params=None, or have
        # callers always supply them.
        video_state, audio_state = self.stage_1(
            denoiser=GuidedDenoiser(
                v_context=v_context_p,
                a_context=a_context_p,
                video_guider=MultiModalGuider(
                    params=video_guider_params,
                    negative_context=v_context_n,
                ),
                audio_guider=MultiModalGuider(
                    params=audio_guider_params,
                    negative_context=a_context_n,
                ),
            ),
            sigmas=sigmas,
            noiser=noiser,
            stepper=stepper,
            width=stage_1_output_shape.width,
            height=stage_1_output_shape.height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(context=v_context_p, conditionings=stage_1_conditionings),
            audio=ModalitySpec(context=a_context_p),
            loop=res2s_audio_video_denoising_loop,
            streaming_prefetch_count=streaming_prefetch_count,
            max_batch_size=max_batch_size,
        )

        # Only the first batch entry of the stage-1 latent is upsampled
        # (presumably any extra entries come from guidance -- TODO confirm).
        upscaled_video_latent = self.upsampler(video_state.latent[:1])

        # Stage 2 re-encodes the image conditionings at the full resolution.
        stage_2_conditionings = self.image_conditioner(
            lambda enc: combined_image_conditionings(
                images=images,
                height=height,
                width=width,
                video_encoder=enc,
                dtype=dtype,
                device=self.device,
            )
        )

        # Both modalities are re-noised to the first stage-2 sigma.
        stage_2_noise_scale = stage_2_sigmas[0].item()
        video_state, audio_state = self.stage_2(
            denoiser=SimpleDenoiser(v_context=v_context_p, a_context=a_context_p),
            sigmas=stage_2_sigmas.to(dtype=torch.float32, device=self.device),
            noiser=noiser,
            stepper=stepper,
            width=width,
            height=height,
            frames=num_frames,
            fps=frame_rate,
            video=ModalitySpec(
                context=v_context_p,
                conditionings=stage_2_conditionings,
                noise_scale=stage_2_noise_scale,
                initial_latent=upscaled_video_latent,
            ),
            audio=ModalitySpec(
                context=a_context_p,
                noise_scale=stage_2_noise_scale,
                initial_latent=audio_state.latent,
            ),
            loop=res2s_audio_video_denoising_loop,
            streaming_prefetch_count=streaming_prefetch_count,
        )

        decoded_video = self.video_decoder(video_state.latent, tiling_config, generator)
        decoded_audio = self.audio_decoder(audio_state.latent)
        return decoded_video, decoded_audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
|
| 273 |
# Model repos
|
|
|
|
| 282 |
# LoRA cache directory and currently-applied key
|
| 283 |
LORA_CACHE_DIR = Path("lora_cache")
|
| 284 |
LORA_CACHE_DIR.mkdir(exist_ok=True)
|
|
|
|
| 285 |
|
| 286 |
+
current_lora_key: str | None = None
|
| 287 |
PENDING_LORA_KEY: str | None = None
|
| 288 |
+
PENDING_LORA_LORAS: tuple[LoraPathStrengthAndSDOps, ...] | None = None
|
| 289 |
+
PENDING_LORA_STATUS: str = "No LoRA config prepared yet."
|
| 290 |
|
| 291 |
weights_dir = Path("weights")
|
| 292 |
weights_dir.mkdir(exist_ok=True)
|
|
|
|
| 382 |
progress=gr.Progress(track_tqdm=True),
|
| 383 |
):
|
| 384 |
"""
|
| 385 |
+
Prepare the LoRA selection for the guided pipeline.
|
| 386 |
+
This caches the LoRA config, not fused weights.
|
|
|
|
|
|
|
|
|
|
| 387 |
"""
|
| 388 |
+
global PENDING_LORA_KEY, PENDING_LORA_LORAS, PENDING_LORA_STATUS
|
| 389 |
+
|
| 390 |
+
key = _make_lora_key(
|
| 391 |
+
pose_strength, general_strength, motion_strength, dreamlay_strength,
|
| 392 |
+
mself_strength, dramatic_strength, fluid_strength, liquid_strength,
|
| 393 |
+
demopose_strength, voice_strength, realism_strength, transition_strength
|
| 394 |
+
)
|
| 395 |
+
cache_path = LORA_CACHE_DIR / f"{key}.json"
|
| 396 |
+
|
| 397 |
+
progress(0.05, desc="Preparing LoRA config")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
|
| 399 |
entries = [
|
| 400 |
(pose_lora_path, round(float(pose_strength), 2)),
|
|
|
|
| 410 |
(realism_lora_path, round(float(realism_strength), 2)),
|
| 411 |
(transition_lora_path, round(float(transition_strength), 2)),
|
| 412 |
]
|
| 413 |
+
|
| 414 |
loras_for_builder = [
|
| 415 |
LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
|
| 416 |
for path, strength in entries
|
|
|
|
| 419 |
|
| 420 |
if not loras_for_builder:
|
| 421 |
PENDING_LORA_KEY = None
|
| 422 |
+
PENDING_LORA_LORAS = None
|
| 423 |
PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
|
| 424 |
return PENDING_LORA_STATUS
|
| 425 |
|
|
|
|
|
|
|
| 426 |
try:
|
| 427 |
+
if cache_path.exists():
|
| 428 |
+
progress(0.20, desc="Loading cached LoRA config")
|
| 429 |
+
data = json.loads(cache_path.read_text())
|
| 430 |
+
loras_for_builder = [
|
| 431 |
+
LoraPathStrengthAndSDOps(item["path"], item["strength"], LTXV_LORA_COMFY_RENAMING_MAP)
|
| 432 |
+
for item in data
|
| 433 |
+
if float(item["strength"]) != 0.0
|
| 434 |
+
]
|
| 435 |
+
else:
|
| 436 |
+
progress(0.30, desc="Saving LoRA config cache")
|
| 437 |
+
cache_path.write_text(
|
| 438 |
+
json.dumps(
|
| 439 |
+
[{"path": path, "strength": strength} for path, strength in entries if float(strength) != 0.0],
|
| 440 |
+
indent=2,
|
| 441 |
+
)
|
| 442 |
+
)
|
|
|
|
|
|
|
| 443 |
|
| 444 |
PENDING_LORA_KEY = key
|
| 445 |
+
PENDING_LORA_LORAS = tuple(loras_for_builder)
|
| 446 |
+
PENDING_LORA_STATUS = f"Prepared LoRA config: {cache_path.name}"
|
| 447 |
return PENDING_LORA_STATUS
|
| 448 |
|
| 449 |
except Exception as e:
|
|
|
|
| 451 |
print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
|
| 452 |
print(traceback.format_exc())
|
| 453 |
PENDING_LORA_KEY = None
|
| 454 |
+
PENDING_LORA_LORAS = None
|
| 455 |
PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
|
| 456 |
return PENDING_LORA_STATUS
|
| 457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
|
| 459 |
+
def apply_prepared_lora_config_to_pipeline():
    """Make the pending LoRA selection the active one by rebuilding the pipeline.

    Returns:
        ``False`` when no LoRA config has been prepared; ``True`` when the
        pending config is active on return, either because it already was or
        because the global ``pipeline`` was just rebuilt with it.
    """
    # Only the names assigned below need a ``global`` declaration; the
    # PENDING_* values are read-only here.
    global current_lora_key, pipeline

    nothing_prepared = PENDING_LORA_KEY is None or PENDING_LORA_LORAS is None
    if nothing_prepared:
        print("[LoRA] No prepared LoRA config available; skipping.")
        return False

    if PENDING_LORA_KEY != current_lora_key:
        # A different selection is pending: construct a fresh pipeline with it
        # and record its key only after construction succeeded.
        pipeline = LTX23NegativePromptTwoStagePipeline(
            checkpoint_path=str(checkpoint_path),
            spatial_upsampler_path=str(spatial_upsampler_path),
            gemma_root=str(gemma_root),
            loras=PENDING_LORA_LORAS,
            quantization=QuantizationPolicy.fp8_cast(),
        )
        current_lora_key = PENDING_LORA_KEY
        print("[LoRA] Prepared LoRA config applied by rebuilding the pipeline.")
    else:
        print("[LoRA] Prepared LoRA config already active; skipping.")

    return True
|
| 481 |
|
| 482 |
# ---- REPLACE PRELOAD BLOCK START ----
|
|
|
|
| 568 |
def get_gpu_duration(
|
| 569 |
first_image,
|
| 570 |
last_image,
|
|
|
|
| 571 |
prompt: str,
|
| 572 |
+
negative_prompt: str,
|
| 573 |
duration: float,
|
| 574 |
gpu_duration: float,
|
| 575 |
enhance_prompt: bool = True,
|
|
|
|
| 598 |
def generate_video(
|
| 599 |
first_image,
|
| 600 |
last_image,
|
|
|
|
| 601 |
prompt: str,
|
| 602 |
+
negative_prompt: str,
|
| 603 |
duration: float,
|
| 604 |
gpu_duration: float,
|
| 605 |
enhance_prompt: bool = True,
|
|
|
|
| 662 |
|
| 663 |
video, audio = pipeline(
|
| 664 |
prompt=prompt,
|
| 665 |
+
negative_prompt=negative_prompt,
|
| 666 |
seed=current_seed,
|
| 667 |
height=int(height),
|
| 668 |
width=int(width),
|
| 669 |
num_frames=num_frames,
|
| 670 |
frame_rate=frame_rate,
|
| 671 |
images=images,
|
|
|
|
| 672 |
tiling_config=tiling_config,
|
| 673 |
enhance_prompt=enhance_prompt,
|
| 674 |
+
# if your wrapper exposes them:
|
| 675 |
+
video_guider_params=video_guider_params,
|
| 676 |
+
audio_guider_params=audio_guider_params,
|
| 677 |
)
|
| 678 |
|
| 679 |
log_memory("after pipeline call")
|
|
|
|
| 706 |
with gr.Row():
|
| 707 |
first_image = gr.Image(label="First Frame (Optional)", type="pil")
|
| 708 |
last_image = gr.Image(label="Last Frame (Optional)", type="pil")
|
|
|
|
| 709 |
prompt = gr.Textbox(
|
| 710 |
label="Prompt",
|
| 711 |
info="for best results - make it as elaborate as possible",
|
|
|
|
| 713 |
lines=3,
|
| 714 |
placeholder="Describe the motion and animation you want...",
|
| 715 |
)
|
| 716 |
+
negative_prompt = gr.Textbox(
|
| 717 |
+
label="Negative Prompt",
|
| 718 |
+
value="",
|
| 719 |
+
lines=2,
|
| 720 |
+
placeholder="Describe what you want to avoid...",
|
| 721 |
+
)
|
| 722 |
duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=30.0, value=10.0, step=0.1)
|
| 723 |
|
| 724 |
|
|
|
|
| 805 |
[
|
| 806 |
None,
|
| 807 |
"pinkknit.jpg",
|
|
|
|
| 808 |
"The camera falls downward through darkness as if dropped into a tunnel. "
|
| 809 |
"As it slows, five friends wearing pink knitted hats and sunglasses lean "
|
| 810 |
"over and look down toward the camera with curious expressions. The lens "
|
| 811 |
"has a strong fisheye effect, creating a circular frame around them. They "
|
| 812 |
"crowd together closely, forming a symmetrical cluster while staring "
|
| 813 |
"directly into the lens.",
|
| 814 |
+
"",
|
| 815 |
3.0,
|
| 816 |
80.0,
|
| 817 |
False,
|
|
|
|
| 834 |
],
|
| 835 |
],
|
| 836 |
inputs=[
|
| 837 |
+
first_image, last_image, prompt, negative_prompt, duration, gpu_duration,
|
| 838 |
enhance_prompt, seed, randomize_seed, height, width,
|
| 839 |
pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
|
| 840 |
],
|
|
|
|
| 867 |
generate_btn.click(
|
| 868 |
fn=generate_video,
|
| 869 |
inputs=[
|
| 870 |
+
first_image, last_image, prompt, negative_prompt, duration, gpu_duration, enhance_prompt,
|
| 871 |
seed, randomize_seed, height, width,
|
| 872 |
pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength,
|
| 873 |
],
|