dagloop5 committed on
Commit
0dd62b1
·
verified ·
1 Parent(s): 8002597

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -92
app.py CHANGED
@@ -40,6 +40,7 @@ torch._dynamo.config.disable = True
40
  import spaces
41
  import gradio as gr
42
  import numpy as np
 
43
  from huggingface_hub import hf_hub_download, snapshot_download
44
 
45
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
@@ -265,6 +266,17 @@ print("=" * 80)
265
  print("Downloading LTX-2.3 distilled model + Gemma...")
266
  print("=" * 80)
267
 
 
 
 
 
 
 
 
 
 
 
 
268
  checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
269
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
270
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
@@ -382,7 +394,7 @@ def generate_video(
382
  progress=gr.Progress(track_tqdm=True),
383
  ):
384
  try:
385
- global pipeline # <<< ADD THIS LINE HERE (VERY TOP of try block)
386
  torch.cuda.reset_peak_memory_stats()
387
  log_memory("start")
388
 
@@ -417,75 +429,52 @@ def generate_video(
417
  tiling_config = TilingConfig.default()
418
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
419
 
420
- # >>> RUNTIME LoRA application (robust, multi-fallback)
421
- # We cannot rely on mutating the original descriptor (some implementations are immutable),
422
- # so create a fresh runtime descriptor and try multiple ways to install it.
423
- runtime_strength = float(lora_strength)
424
- replaced = False
425
-
426
- # 1) Try simple approach: build a new LoraPathStrengthAndSDOps
427
- runtime_lora = LoraPathStrengthAndSDOps(lora_path, runtime_strength, LTXV_LORA_COMFY_RENAMING_MAP)
428
- print(f"[LoRA] attempting to apply runtime LoRA (strength={runtime_strength})")
429
-
430
- # Try a few likely places to replace the descriptor used by the pipeline/ledger.
431
- try:
432
- # common attribute on pipeline
433
- if hasattr(pipeline, "loras"):
434
- try:
435
- pipeline.loras = [runtime_lora]
436
- replaced = True
437
- print("[LoRA] replaced pipeline.loras")
438
- except Exception as e:
439
- print(f"[LoRA] pipeline.loras assignment failed: {e}")
440
- except Exception:
441
- pass
442
 
443
- try:
444
- # common attribute on the model ledger
445
- if hasattr(pipeline, "model_ledger") and hasattr(pipeline.model_ledger, "loras"):
 
446
  try:
447
- pipeline.model_ledger.loras = [runtime_lora]
448
- replaced = True
449
- print("[LoRA] replaced pipeline.model_ledger.loras")
450
- except Exception as e:
451
- print(f"[LoRA] pipeline.model_ledger.loras assignment failed: {e}")
452
- except Exception:
453
- pass
454
 
455
- try:
456
- # some internals use a private _loras list
457
- if hasattr(pipeline, "model_ledger") and hasattr(pipeline.model_ledger, "_loras"):
458
- try:
459
- pipeline.model_ledger._loras = [runtime_lora]
460
- replaced = True
461
- print("[LoRA] replaced pipeline.model_ledger._loras")
462
- except Exception as e:
463
- print(f"[LoRA] pipeline.model_ledger._loras assignment failed: {e}")
464
- except Exception:
465
- pass
466
-
467
- # 2) If we succeeded replacing the descriptor in-place, clear transformer cache so it will rebuild
468
- if replaced:
469
- try:
470
- if hasattr(pipeline.model_ledger, "_transformer"):
471
- pipeline.model_ledger._transformer = None
472
- # also clear potential caches named similar to 'transformer_cache' if present
473
- if hasattr(pipeline.model_ledger, "transformer_cache"):
474
  try:
475
- pipeline.model_ledger.transformer_cache = {}
476
  except Exception:
477
  pass
478
- print("[LoRA] in-place descriptor replacement done; transformer cache cleared")
479
- except Exception as e:
480
- print(f"[LoRA] replacement succeeded but cache clearing failed: {e}")
481
-
482
- # 3) FINAL FALLBACK - if none of the in-place replacements worked, rebuild the pipeline
483
- if not replaced:
484
- print("[LoRA] in-place replacement FAILED; rebuilding pipeline with runtime LoRA (this is slow)")
485
- try:
486
- # Rebuild pipeline object with the new LoRA descriptor
487
- # NOTE: this replaces the global `pipeline`. We must declare global to reassign it.
488
- pipeline = LTX23DistilledA2VPipeline(
489
  distilled_checkpoint_path=checkpoint_path,
490
  spatial_upsampler_path=spatial_upsampler_path,
491
  gemma_root=gemma_root,
@@ -493,35 +482,50 @@ def generate_video(
493
  quantization=QuantizationPolicy.fp8_cast(),
494
  )
495
 
496
- # After rebuilding, we *do not* re-run the original module-level preloads here,
497
- # because re-pinning may be complex; the rebuilt pipeline will construct its
498
- # own ledger as part of the first call. This is slower but reliable.
499
- # Clear any transformer caches if they exist on the new ledger as well.
500
  try:
501
- if hasattr(pipeline.model_ledger, "_transformer"):
502
- pipeline.model_ledger._transformer = None
503
- except Exception:
504
- pass
505
-
506
- print("[LoRA] pipeline rebuilt with runtime LoRA")
507
- except Exception as e:
508
- print(f"[LoRA] pipeline rebuild FAILED: {e}")
509
-
510
- # Reset transformer so next call rebuilds it with new LoRA strength (NO preloading!)
511
- try:
512
- if hasattr(pipeline, "model_ledger"):
513
- if hasattr(pipeline.model_ledger, "_transformer"):
514
- del pipeline.model_ledger._transformer
515
- pipeline.model_ledger._transformer = None
516
-
517
- # CRITICAL: force cleanup BEFORE rebuild happens
518
- cleanup_memory()
519
- torch.cuda.empty_cache()
520
-
521
- print("[LoRA] transformer reset; will rebuild during inference")
522
- except Exception as e:
523
- print(f"[LoRA] transformer reset failed: {e}")
524
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  log_memory("before pipeline call")
526
 
527
  video, audio = pipeline(
 
40
  import spaces
41
  import gradio as gr
42
  import numpy as np
43
+ from collections import OrderedDict
44
  from huggingface_hub import hf_hub_download, snapshot_download
45
 
46
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
 
266
  print("Downloading LTX-2.3 distilled model + Gemma...")
267
  print("=" * 80)
268
 
269
+ # ----------------------------
270
+ # Pipeline cache for LoRA strengths (keeps at most 2 pipelines to limit VRAM)
271
+ # ----------------------------
272
+ # Use rounded strengths as keys (2 decimal places)
273
+ pipeline_cache: OrderedDict[float, LTX23DistilledA2VPipeline] = OrderedDict()
274
+ # Record the current pipeline's LoRA strength (we built the module above with lora_descriptor default 1.0)
275
+ current_lora_strength: float = round(1.0, 2)
276
+ pipeline_cache[current_lora_strength] = pipeline
277
+ CACHE_MAX_SIZE = 2 # keep at most two pipeline instances in memory
278
+ print(f"[CACHE] initialized pipeline cache with strength={current_lora_strength}")
279
+
280
  checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
281
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
282
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
 
394
  progress=gr.Progress(track_tqdm=True),
395
  ):
396
  try:
397
+ global pipeline, pipeline_cache, current_lora_strength
398
  torch.cuda.reset_peak_memory_stats()
399
  log_memory("start")
400
 
 
429
  tiling_config = TilingConfig.default()
430
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
431
 
432
+ # ----------------------------
433
+ # Pipeline-per-strength (small LRU cache) safe, deterministic LoRA switching
434
+ # ----------------------------
435
+ # Globals used: pipeline, pipeline_cache, current_lora_strength, CACHE_MAX_SIZE
436
+ requested_strength = round(float(lora_strength), 2)
437
+
438
+ # fast-path: same strength currently loaded
439
+ if requested_strength == current_lora_strength:
440
+ print(f"[LoRA] requested strength {requested_strength} == current {current_lora_strength} -> using current pipeline")
441
+ else:
442
+ print(f"[LoRA] requested strength {requested_strength} != current {current_lora_strength}")
443
+
444
+ # if cached, swap to that pipeline (move to end to mark as recently used)
445
+ if requested_strength in pipeline_cache:
446
+ print(f"[LoRA] using cached pipeline for strength={requested_strength}")
447
+ # set pipeline to cached instance & mark as most-recently-used
448
+ cached = pipeline_cache.pop(requested_strength)
449
+ pipeline_cache[requested_strength] = cached
450
+ pipeline = cached
451
+ current_lora_strength = requested_strength
 
 
452
 
453
+ else:
454
+ # Build new pipeline for requested strength.
455
+ print(f"[LoRA] building new pipeline for strength={requested_strength} (this will free and reallocate memory)")
456
+ # Free the previous pipeline & its GPU memory BEFORE building the new one
457
  try:
458
+ # remove previous pipeline from cache (if present)
459
+ if current_lora_strength in pipeline_cache:
460
+ pipeline_cache.pop(current_lora_strength, None)
 
 
 
 
461
 
462
+ # delete current pipeline object reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  try:
464
+ del pipeline
465
  except Exception:
466
  pass
467
+
468
+ # aggressively free memory
469
+ cleanup_memory()
470
+ torch.cuda.empty_cache()
471
+ print("[LoRA] freed memory, starting pipeline build")
472
+ except Exception as e:
473
+ print(f"[LoRA] error while freeing old pipeline: {e}")
474
+
475
+ # create a runtime LoRA descriptor and build a fresh pipeline
476
+ runtime_lora = LoraPathStrengthAndSDOps(lora_path, float(requested_strength), LTXV_LORA_COMFY_RENAMING_MAP)
477
+ new_pipeline = LTX23DistilledA2VPipeline(
478
  distilled_checkpoint_path=checkpoint_path,
479
  spatial_upsampler_path=spatial_upsampler_path,
480
  gemma_root=gemma_root,
 
482
  quantization=QuantizationPolicy.fp8_cast(),
483
  )
484
 
485
+ # Pin safe components (same preloads as original) so heavy parts remain stable.
 
 
 
486
  try:
487
+ ledger = new_pipeline.model_ledger
488
+ _video_encoder = ledger.video_encoder()
489
+ _video_decoder = ledger.video_decoder()
490
+ _audio_encoder = ledger.audio_encoder()
491
+ _audio_decoder = ledger.audio_decoder()
492
+ _vocoder = ledger.vocoder()
493
+ _spatial_upsampler = ledger.spatial_upsampler()
494
+ _text_encoder = ledger.text_encoder()
495
+ _embeddings_processor = ledger.gemma_embeddings_processor()
496
+
497
+ ledger.video_encoder = lambda: _video_encoder
498
+ ledger.video_decoder = lambda: _video_decoder
499
+ ledger.audio_encoder = lambda: _audio_encoder
500
+ ledger.audio_decoder = lambda: _audio_decoder
501
+ ledger.vocoder = lambda: _vocoder
502
+ ledger.spatial_upsampler = lambda: _spatial_upsampler
503
+ ledger.text_encoder = lambda: _text_encoder
504
+ ledger.gemma_embeddings_processor = lambda: _embeddings_processor
505
+ print("[LoRA] new pipeline preloaded and pinned safe components")
506
+ except Exception as e:
507
+ print(f"[LoRA] warning preloading pinned components failed: {e}")
508
+
509
+ # Set as current pipeline and cache it
510
+ pipeline = new_pipeline
511
+ pipeline_cache[requested_strength] = pipeline
512
+ current_lora_strength = requested_strength
513
+
514
+ # Evict oldest if cache size exceeded
515
+ try:
516
+ while len(pipeline_cache) > CACHE_MAX_SIZE:
517
+ evicted_strength, evicted_pipeline = pipeline_cache.popitem(last=False)
518
+ try:
519
+ del evicted_pipeline
520
+ except Exception:
521
+ pass
522
+ cleanup_memory()
523
+ torch.cuda.empty_cache()
524
+ print(f"[CACHE] evicted pipeline strength={evicted_strength}")
525
+ except Exception as e:
526
+ print(f"[CACHE] eviction error: {e}")
527
+
528
+ # end of pipeline-per-strength swap/build
529
  log_memory("before pipeline call")
530
 
531
  video, audio = pipeline(