dagloop5 committed on
Commit
b5194ca
·
verified ·
1 Parent(s): cf0e41e

Delete app(draft2).py

Browse files
Files changed (1) hide show
  1. app(draft2).py +0 -800
app(draft2).py DELETED
@@ -1,800 +0,0 @@
1
# =============================================================================
# Installation and Setup
# =============================================================================
import os
import subprocess
import sys

# Disable torch.compile / dynamo before any torch import
# This prevents CUDA initialization issues in the Space environment
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Clone LTX-2 repo at specific commit for reproducibility
# The commit ensures we have the exact pipeline code matching our analysis
LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
# Using specific commit for stability - can be updated to main later
LTX_COMMIT_SHA = "a2c3f24078eb918171967f74b6f66b756b29ee45"

if not os.path.exists(LTX_REPO_DIR):
    print(f"Cloning {LTX_REPO_URL} at commit {LTX_COMMIT_SHA}...")
    # init + fetch-by-SHA rather than `git clone`: this allows a shallow
    # (--depth 1) download of one arbitrary commit, which clone cannot do.
    os.makedirs(LTX_REPO_DIR)
    subprocess.run(["git", "init", LTX_REPO_DIR], check=True)
    subprocess.run(["git", "remote", "add", "origin", LTX_REPO_URL], cwd=LTX_REPO_DIR, check=True)
    subprocess.run(["git", "fetch", "--depth", "1", "origin", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)
    subprocess.run(["git", "checkout", LTX_COMMIT_SHA], cwd=LTX_REPO_DIR, check=True)

# Add repo packages to Python path
# This allows us to import from ltx-core and ltx-pipelines
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-pipelines", "src"))
sys.path.insert(0, os.path.join(LTX_REPO_DIR, "packages", "ltx-core", "src"))
32
-
33
# =============================================================================
# Imports
# =============================================================================
import logging
import random
import tempfile
from pathlib import Path

import torch
# Disable torch.compile/dynamo at runtime level
# (redundant with the env vars set above; kept as a belt-and-braces safety net)
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.disable = True

import gradio as gr
import spaces
import numpy as np
from huggingface_hub import hf_hub_download, snapshot_download

# Import from the cloned LTX-2 pipeline
# These imports come from ti2vid_two_stages_hq.py
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.quantization import QuantizationPolicy
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
from ltx_pipelines.utils.args import ImageConditioningInput
from ltx_pipelines.utils.media_io import encode_video
from ltx_pipelines.utils.constants import LTX_2_3_HQ_PARAMS
from ltx_core.components.guiders import MultiModalGuiderParams
61
-
62
# =============================================================================
# Constants and Configuration
# =============================================================================

# Model repository on Hugging Face
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"

# Default parameters from LTX_2_3_HQ_PARAMS
DEFAULT_FRAME_RATE = 24.0

# Resolution constraints (must be divisible by 64 for two-stage pipeline)
# The pipeline generates at half-resolution in Stage 1, so input must be divisible by 2
MIN_DIM = 256
MAX_DIM = 1280
STEP = 64  # Both width and height must be divisible by 64

# Duration constraints (frames must be 8*K + 1)
MIN_FRAMES = 9  # 8*1 + 1
MAX_FRAMES = 257  # 8*32 + 1

# Seed range (int32 max keeps seeds representable everywhere)
MAX_SEED = np.iinfo(np.int32).max

# Default prompts shown in the UI text boxes
DEFAULT_PROMPT = (
    "A majestic eagle soaring over mountain peaks at sunset, "
    "wings spread wide against the orange sky, feathers catching the light, "
    "wind currents visible in the motion blur, cinematic slow motion, 4K quality"
)
DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, inconsistent motion, blurry, jittery, distorted, "
    "deformed, artifacts, text, watermark, logo, frame, border, "
    "low resolution, pixelated, unnatural, fake, CGI, cartoon"
)
97
-
98
# =============================================================================
# Model Download and Initialization
# =============================================================================

print("=" * 80)
print("Downloading LTX-2.3 models...")
print("=" * 80)

# Download all required model files (hf_hub_download caches, so restarts are cheap)
# 1. Dev checkpoint - full trainable 22B model
checkpoint_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-22b-dev.safetensors"
)
print(f"Dev checkpoint: {checkpoint_path}")

# 2. Spatial upscaler - x2 upscaler for latent space
spatial_upsampler_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors"
)
print(f"Spatial upsampler: {spatial_upsampler_path}")

# 3. Distilled LoRA - distilled knowledge in LoRA format (rank 384)
# This LoRA is specifically trained to work with the dev model
distilled_lora_path = hf_hub_download(
    repo_id=LTX_MODEL_REPO,
    filename="ltx-2.3-22b-distilled-lora-384.safetensors"
)
print(f"Distilled LoRA: {distilled_lora_path}")

# 4. Gemma text encoder - required for prompt encoding
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
print(f"Gemma root: {gemma_root}")

print("=" * 80)
print("All models downloaded!")
print("=" * 80)
136
-
137
# =============================================================================
# Pipeline Initialization
# =============================================================================

# Create the LoraPathStrengthAndSDOps for distilled LoRA
# The sd_ops parameter uses the ComfyUI renaming map for compatibility
from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP

distilled_lora = [
    LoraPathStrengthAndSDOps(
        path=distilled_lora_path,
        strength=1.0,  # Will be set per-stage (0.25 for stage 1, 0.5 for stage 2)
        sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
    )
]

# Initialize the Two-Stage HQ Pipeline
# Key parameters:
# - checkpoint_path: Full dev model (trainable)
# - distilled_lora: LoRA containing distilled knowledge
# - distilled_lora_strength_stage_1: 0.25 (lighter application at half-res)
# - distilled_lora_strength_stage_2: 0.5 (stronger application after upscaling)
# - spatial_upsampler_path: Required for two-stage upscaling
# - gemma_root: Gemma text encoder for prompt encoding
print("Initializing LTX-2.3 Two-Stage HQ Pipeline...")

pipeline = TI2VidTwoStagesHQPipeline(
    checkpoint_path=checkpoint_path,
    distilled_lora=distilled_lora,
    distilled_lora_strength_stage_1=0.25,  # From HQ params
    distilled_lora_strength_stage_2=0.50,  # From HQ params
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    loras=(),  # No additional custom LoRAs for this Space
    quantization=QuantizationPolicy.fp8_cast(),  # FP8 for memory efficiency
    torch_compile=False,  # Disable for Space compatibility
)

print("Pipeline initialized successfully!")
print("=" * 80)
177
-
178
# =============================================================================
# ZeroGPU Tensor Preloading - CPU Tensor Approach
# =============================================================================
# ZeroGPU should pack any tensors in memory, not just GPU tensors.
# We load model weights to CPU as proxy tensors to trigger packing.
# During actual generation, ZeroGPU will move them to GPU.

print("Creating CPU proxy tensors for ZeroGPU tensor packing...")
print("This may take a few minutes (loading to CPU only)...")

import gc

# Create small proxy tensors for each model component
# These don't need to be the actual weights - just tensors to trigger packing
# ZeroGPU will pack whatever tensors exist when it runs

# Module-level registry that keeps the proxies referenced until they are
# re-bound to named globals further below.
_proxy_tensors = []
195
-
196
def create_proxy(name, shape, dtype=torch.float32):
    """Allocate a zero-filled CPU tensor and register it for ZeroGPU packing.

    The tensor is appended to the module-level ``_proxy_tensors`` list so it
    stays referenced (and therefore packable) after this call returns.
    """
    proxy = torch.zeros(shape, dtype=dtype)
    print(f" Creating proxy for {name}: {shape}")
    _proxy_tensors.append(proxy)
    return proxy
202
-
203
# Create proxies for various model components
# These are just to ensure tensors exist in memory for ZeroGPU to pack
create_proxy("transformer_stage1", (1, 1024, 512))
create_proxy("transformer_stage2", (1, 1024, 512))
create_proxy("video_encoder", (1, 768, 512))
create_proxy("video_decoder", (1, 512, 512))
create_proxy("audio_decoder", (1, 256, 512))
create_proxy("spatial_upsampler", (1, 256, 512))
create_proxy("text_encoder", (1, 2048, 256))
create_proxy("vocoder", (1, 128, 256))

# Keep proxies alive by storing in module globals
proxy_stage1 = _proxy_tensors[0]
proxy_stage2 = _proxy_tensors[1]
proxy_venc = _proxy_tensors[2]
proxy_vdec = _proxy_tensors[3]
proxy_adec = _proxy_tensors[4]
proxy_upsamp = _proxy_tensors[5]
proxy_tenc = _proxy_tensors[6]
proxy_voc = _proxy_tensors[7]

# Clean up the temporary list
# NOTE(review): create_proxy() closes over _proxy_tensors, so it must not be
# called again after this del — verify no later caller exists.
del _proxy_tensors

# Now trigger the actual model loading but catch GPU errors
print("\nAttempting model initialization (GPU errors expected)...")

try:
    # Try to access components - this will trigger loading but fail on GPU
    # (private attribute access; best-effort warm-up, errors are non-fatal)
    _ = pipeline.stage_1._transformer_ctx
    _ = pipeline.prompt_encoder._text_encoder_ctx
    print(" Model contexts accessed")
except Exception as e:
    # Deliberate broad catch: any failure here only means lazy loading
    # happens at first generation instead of startup.
    print(f" Context access: {type(e).__name__}")

print("\n" + "=" * 80)
print("Startup complete. Models will load to GPU during first generation.")
print("=" * 80)
241
-
242
- # =============================================================================
243
- # Helper Functions
244
- # =============================================================================
245
-
246
def log_memory(tag: str):
    """Print a tagged one-line snapshot of GPU memory usage.

    No-op when CUDA is unavailable, so it is safe to call on CPU-only hosts.
    """
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    free, total = torch.cuda.mem_get_info()
    print(f"[VRAM {tag}] allocated={allocated:.2f}GB peak={peak:.2f}GB free={free / 1024**3:.2f}GB total={total / 1024**3:.2f}GB")
253
-
254
-
255
def calculate_frames(duration: float, frame_rate: float = DEFAULT_FRAME_RATE) -> int:
    """
    Convert a duration in seconds into a valid LTX frame count.

    The temporal upsampling architecture requires the frame count to be of
    the form 8*K + 1 (K a non-negative integer), so the ideal count is
    snapped to the nearest such value and clamped to [MIN_FRAMES, MAX_FRAMES].

    Args:
        duration: Requested clip length in seconds.
        frame_rate: Playback frames per second.

    Returns:
        A frame count satisfying the 8*K + 1 constraint.
    """
    # Floor of the requested count, never below the model minimum.
    requested = max(int(duration * frame_rate), MIN_FRAMES)
    # Snap to the nearest value of the form 8*K + 1.
    snapped = 8 * round((requested - 1) / 8) + 1
    # Cap at the architectural maximum.
    return min(snapped, MAX_FRAMES)
277
-
278
-
279
def validate_resolution(height: int, width: int) -> tuple[int, int]:
    """
    Snap a requested resolution to one the two-stage pipeline accepts.

    Each dimension is rounded to the nearest multiple of 64 (STEP) and then
    clamped to [MIN_DIM, MAX_DIM] — both bounds are themselves multiples of
    64, so the result always satisfies the divisibility requirement. Stage 1
    later runs at half of this resolution.

    Args:
        height: Requested height in pixels.
        width: Requested width in pixels.

    Returns:
        The validated (height, width) pair.
    """

    def _snap(dim: int) -> int:
        # Nearest multiple of STEP, clamped into the supported range.
        nearest = round(dim / STEP) * STEP
        return max(MIN_DIM, min(nearest, MAX_DIM))

    return _snap(height), _snap(width)
303
-
304
-
305
def detect_aspect_ratio(image) -> str:
    """Detect the closest preset aspect ratio ("16:9", "9:16" or "1:1") for an image.

    Accepts PIL images (``.size`` is a (w, h) pair), numpy arrays
    (``.shape`` is (h, w, ...)), or ``None``; anything unrecognized
    falls back to "16:9".
    """
    if image is None:
        return "16:9"

    # BUG FIX: check .shape before .size. Numpy arrays expose BOTH
    # attributes, but ndarray.size is a scalar element count, so the
    # original .size-first branch raised TypeError when unpacking it.
    # PIL images have no .shape, so they still take the .size branch.
    if hasattr(image, "shape"):
        h, w = image.shape[:2]
    elif hasattr(image, "size"):
        w, h = image.size
    else:
        return "16:9"

    ratio = w / h
    candidates = {"16:9": 16/9, "9:16": 9/16, "1:1": 1.0}
    # Pick the preset whose numeric ratio is nearest the actual ratio.
    return min(candidates, key=lambda k: abs(ratio - candidates[k]))
320
-
321
-
322
# Resolution presets based on aspect ratio.
# All dimensions are multiples of 64, as required by validate_resolution().
RESOLUTIONS = {
    "16:9": {"width": 1280, "height": 704},  # 960x540 * 1.33 = 1280x720, halved = 640x360 -> 1280x720
    "9:16": {"width": 704, "height": 1280},
    "1:1": {"width": 960, "height": 960},
}
328
-
329
-
330
def get_duration(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress,
) -> int:
    """
    Dynamically calculate GPU allocation time based on generation parameters.

    Used by @spaces.GPU to set the appropriate time limit; the signature must
    mirror generate_video's, but only duration, height and width influence
    the estimate. Longer videos and higher resolution require more time.

    Args:
        duration: Video duration in seconds
        height, width: Target resolution in pixels

    Returns:
        GPU allocation time in seconds
    """
    base = 60

    # Longer videos need more time
    if duration > 4:
        base += 15
    if duration > 6:
        base += 15

    # Higher resolution needs more time
    if height > 700 or width > 1000:
        base += 15

    # More frames means more processing
    # Calculate num_frames inside get_duration since it's no longer a parameter
    frames_from_duration = int(duration * DEFAULT_FRAME_RATE)
    if frames_from_duration > 81:
        base += 10

    # BUG FIX: the original computed `base` but never returned it, so
    # @spaces.GPU(duration=get_duration) received None instead of an int.
    return base
381
-
382
-
383
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def generate_video(
    prompt: str,
    negative_prompt: str,
    input_image,
    duration: float,
    seed: int,
    randomize_seed: bool,
    height: int,
    width: int,
    enhance_prompt: bool,
    # Guidance parameters
    video_cfg_scale: float,
    video_stg_scale: float,
    video_rescale_scale: float,
    video_a2v_scale: float,
    audio_cfg_scale: float,
    audio_stg_scale: float,
    audio_rescale_scale: float,
    audio_v2a_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate high-quality video using the Two-Stage HQ Pipeline.

    This function implements a two-stage generation process:

    Stage 1 (Half Resolution + CFG):
    - Generates video at half the target resolution
    - Uses GuidedDenoiser with CFG (positive + negative prompts)
    - Applies distilled LoRA at strength 0.25
    - Res2s sampler for efficient second-order denoising

    Stage 2 (Upscale + Refine):
    - Upscales latent representation 2x using spatial upsampler
    - Refines using SimpleDenoiser (no CFG, distilled approach)
    - Applies distilled LoRA at strength 0.5
    - 4-step refined denoising schedule

    Args:
        prompt: Text description of desired video content
        negative_prompt: What to avoid in the video
        input_image: Optional input image for image-to-video
        duration: Video duration in seconds
        seed: Random seed for reproducibility
        randomize_seed: Whether to use a random seed
        height, width: Target resolution (must be divisible by 64)
        enhance_prompt: Whether to use prompt enhancement
        video_cfg_scale: Video CFG (prompt adherence)
        video_stg_scale: Video STG (spatio-temporal guidance)
        video_rescale_scale: Video rescaling factor
        video_a2v_scale: Audio-to-video cross-attention scale
        audio_cfg_scale: Audio CFG (prompt adherence)
        audio_stg_scale: Audio STG (spatio-temporal guidance)
        audio_rescale_scale: Audio rescaling factor
        audio_v2a_scale: Video-to-audio cross-attention scale

    Returns:
        Tuple of (output_video_path, used_seed), or (None, used_seed) on error.
    """
    # BUG FIX: resolve the seed BEFORE entering the try block so the except
    # handler can always reference `current_seed`. Previously, any failure
    # occurring before this assignment raised NameError inside the handler.
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    try:
        torch.cuda.reset_peak_memory_stats()
        log_memory("start")

        print(f"Using seed: {current_seed}")

        # Validate and adjust resolution
        height, width = validate_resolution(int(height), int(width))
        print(f"Resolution: {width}x{height}")

        # Calculate frames (must be 8*K + 1)
        num_frames = calculate_frames(duration, DEFAULT_FRAME_RATE)
        print(f"Frames: {num_frames} ({duration}s @ {DEFAULT_FRAME_RATE}fps)")

        # Prepare image conditioning if provided
        images = []
        if input_image is not None:
            # Save input image temporarily
            output_dir = Path("outputs")
            output_dir.mkdir(exist_ok=True)
            temp_image_path = output_dir / f"temp_input_{current_seed}.jpg"

            if hasattr(input_image, "save"):
                input_image.save(temp_image_path)
            else:
                import shutil
                shutil.copy(input_image, temp_image_path)

            # Create ImageConditioningInput
            # path: image file path
            # frame_idx: target frame to condition on (0 = first frame)
            # strength: conditioning strength (1.0 = full influence)
            images = [ImageConditioningInput(
                path=str(temp_image_path),
                frame_idx=0,
                strength=1.0
            )]

        # Create tiling config for VAE decoding
        # Tiling is necessary to avoid OOM errors during decoding
        tiling_config = TilingConfig.default()
        video_chunks_number = get_video_chunks_number(num_frames, tiling_config)

        # Configure MultiModalGuider parameters
        # These control how the model adheres to prompts and handles modality guidance

        # Video guider parameters
        # cfg_scale: Classifier-free guidance scale (higher = stronger prompt adherence)
        # stg_scale: Spatio-temporal guidance scale (0 = disabled)
        # rescale_scale: Rescaling factor for oversaturation prevention
        # modality_scale: Cross-attention scale (audio-to-video)
        # skip_step: Step skipping for faster inference (0 = no skipping)
        # stg_blocks: Which transformer blocks to perturb for STG
        video_guider_params = MultiModalGuiderParams(
            cfg_scale=video_cfg_scale,
            stg_scale=video_stg_scale,
            rescale_scale=video_rescale_scale,
            modality_scale=video_a2v_scale,
            skip_step=0,
            stg_blocks=[],  # Empty for LTX 2.3 HQ
        )

        # Audio guider parameters
        audio_guider_params = MultiModalGuiderParams(
            cfg_scale=audio_cfg_scale,
            stg_scale=audio_stg_scale,
            rescale_scale=audio_rescale_scale,
            modality_scale=audio_v2a_scale,
            skip_step=0,
            stg_blocks=[],  # Empty for LTX 2.3 HQ
        )

        log_memory("before pipeline call")

        # Call the pipeline
        # The pipeline uses Res2sDiffusionStep for second-order sampling
        # Stage 1: num_inference_steps from LTX_2_3_HQ_PARAMS (15 steps)
        # Stage 2: Fixed 4-step schedule from STAGE_2_DISTILLED_SIGMAS
        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=LTX_2_3_HQ_PARAMS.num_inference_steps,  # 15 steps
            video_guider_params=video_guider_params,
            audio_guider_params=audio_guider_params,
            images=images,
            tiling_config=tiling_config,
            enhance_prompt=enhance_prompt,
        )

        log_memory("after pipeline call")

        # Encode video with audio
        # BUG FIX: NamedTemporaryFile replaces the deprecated, race-prone
        # tempfile.mktemp; delete=False keeps the file for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            output_path=output_path,
            video_chunks_number=video_chunks_number,
        )

        log_memory("after encode_video")
        return str(output_path), current_seed

    except Exception as e:
        # Broad catch is deliberate: surface the failure in the logs and
        # return a (None, seed) pair the UI can handle instead of crashing.
        import traceback
        log_memory("on error")
        print(f"Error: {str(e)}\n{traceback.format_exc()}")
        return None, current_seed
560
-
561
-
562
# =============================================================================
# Gradio UI
# =============================================================================

css = """
/* Custom styling for LTX-2.3 Space */
.fillable {max-width: 1200px !important}
.progress-text {color: white}
"""

# BUG FIX: theme and css must be passed to gr.Blocks, not to launch().
# Blocks.launch() does not accept these keyword arguments, so the original
# call raised a TypeError at startup and `css` was never applied.
with gr.Blocks(
    title="LTX-2.3 Two-Stage HQ Video Generation",
    theme=gr.themes.Citrus(),
    css=css,
) as demo:
    gr.Markdown("# LTX-2.3 Two-Stage HQ Video Generation")
    gr.Markdown(
        "High-quality text/image-to-video generation using the dev model + distilled LoRA. "
        "[[Model]](https://huggingface.co/Lightricks/LTX-2.3) "
        "[[GitHub]](https://github.com/Lightricks/LTX-2)"
    )

    with gr.Row():
        # Input Column
        with gr.Column():
            # Input image (optional)
            input_image = gr.Image(
                label="Input Image (Optional - for image-to-video)",
                type="pil",
                sources=["upload", "webcam", "clipboard"]
            )

            # Prompt inputs
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe the video you want to generate",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Enter your prompt here..."
            )

            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="What to avoid in the generated video",
                value=DEFAULT_NEGATIVE_PROMPT,
                lines=2,
                placeholder="Enter negative prompt here..."
            )

            # Duration slider
            duration = gr.Slider(
                label="Duration (seconds)",
                minimum=0.5,
                maximum=8.0,
                value=2.0,
                step=0.1,
                info="Video duration (clamped to 8K+1 frames)"
            )

            # Enhance prompt toggle
            enhance_prompt = gr.Checkbox(
                label="Enhance Prompt",
                value=False,
                info="Use Gemma to enhance the prompt for better results"
            )

            # Generate button
            generate_btn = gr.Button("Generate Video", variant="primary", size="lg")

        # Output Column
        with gr.Column():
            output_video = gr.Video(
                label="Generated Video",
                autoplay=True,
                interactive=False
            )

    # Advanced Settings Accordion
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            # Resolution inputs
            width = gr.Number(
                label="Width",
                value=1280,
                precision=0,
                info="Must be divisible by 64"
            )
            height = gr.Number(
                label="Height",
                value=704,
                precision=0,
                info="Must be divisible by 64"
            )

        with gr.Row():
            # Seed controls
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                minimum=0,
                maximum=MAX_SEED
            )
            randomize_seed = gr.Checkbox(
                label="Randomize Seed",
                value=True
            )

        gr.Markdown("### Video Guidance Parameters")
        gr.Markdown("Control how strongly the model follows the video prompt and handles guidance.")

        with gr.Row():
            video_cfg_scale = gr.Slider(
                label="Video CFG Scale",
                minimum=1.0,
                maximum=10.0,
                value=LTX_2_3_HQ_PARAMS.video_guider_params.cfg_scale,
                step=0.1,
                info="Classifier-free guidance for video (higher = stronger prompt adherence)"
            )
            video_stg_scale = gr.Slider(
                label="Video STG Scale",
                minimum=0.0,
                maximum=2.0,
                value=0.0,
                step=0.1,
                info="Spatio-temporal guidance (0 = disabled)"
            )

        with gr.Row():
            video_rescale_scale = gr.Slider(
                label="Video Rescale",
                minimum=0.0,
                maximum=2.0,
                value=0.45,
                step=0.1,
                info="Rescaling factor for oversaturation prevention"
            )
            video_a2v_scale = gr.Slider(
                label="A2V Scale",
                minimum=0.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                info="Audio-to-video cross-attention scale"
            )

        gr.Markdown("### Audio Guidance Parameters")
        gr.Markdown("Control audio generation quality and sync.")

        with gr.Row():
            audio_cfg_scale = gr.Slider(
                label="Audio CFG Scale",
                minimum=1.0,
                maximum=15.0,
                value=LTX_2_3_HQ_PARAMS.audio_guider_params.cfg_scale,
                step=0.1,
                info="Classifier-free guidance for audio"
            )
            audio_stg_scale = gr.Slider(
                label="Audio STG Scale",
                minimum=0.0,
                maximum=2.0,
                value=0.0,
                step=0.1,
                info="Spatio-temporal guidance for audio (0 = disabled)"
            )

        with gr.Row():
            audio_rescale_scale = gr.Slider(
                label="Audio Rescale",
                minimum=0.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                info="Audio rescaling factor"
            )
            audio_v2a_scale = gr.Slider(
                label="V2A Scale",
                minimum=0.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                info="Video-to-audio cross-attention scale"
            )

    # Event handlers
    def on_image_upload(image, current_h, current_w):
        """Update resolution preset fields to match an uploaded image's aspect ratio."""
        if image is None:
            return gr.update(), gr.update()

        aspect = detect_aspect_ratio(image)
        if aspect in RESOLUTIONS:
            # Return order is (width, height) to match the outputs list below.
            return (
                gr.update(value=RESOLUTIONS[aspect]["width"]),
                gr.update(value=RESOLUTIONS[aspect]["height"])
            )
        return gr.update(), gr.update()

    input_image.change(
        fn=on_image_upload,
        inputs=[input_image, height, width],
        outputs=[width, height],
    )

    # Generate button click handler
    generate_btn.click(
        fn=generate_video,
        inputs=[
            prompt,
            negative_prompt,
            input_image,
            duration,
            seed,
            randomize_seed,
            height,
            width,
            enhance_prompt,
            video_cfg_scale,
            video_stg_scale,
            video_rescale_scale,
            video_a2v_scale,
            audio_cfg_scale,
            audio_stg_scale,
            audio_rescale_scale,
            audio_v2a_scale,
        ],
        outputs=[output_video, seed],
    )


# =============================================================================
# Main Entry Point
# =============================================================================

if __name__ == "__main__":
    demo.queue().launch(
        mcp_server=True,
        share=True,
    )