dagloop5 commited on
Commit
60a179a
·
verified ·
1 Parent(s): f88c8f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -165
app.py CHANGED
@@ -252,10 +252,6 @@ LORA_CACHE_DIR = Path("lora_cache")
252
  LORA_CACHE_DIR.mkdir(exist_ok=True)
253
  current_lora_key: str | None = None
254
 
255
- PENDING_LORA_KEY: str | None = None
256
- PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
257
- PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
258
-
259
  weights_dir = Path("weights")
260
  weights_dir.mkdir(exist_ok=True)
261
  checkpoint_path = hf_hub_download(
@@ -264,7 +260,6 @@ checkpoint_path = hf_hub_download(
264
  local_dir=str(weights_dir),
265
  local_dir_use_symlinks=False,
266
  )
267
- spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
268
 
269
  print("[Gemma] Setting up abliterated Gemma text encoder...")
270
  MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
@@ -367,7 +362,6 @@ print(f"Transition LoRA: {transition_lora_path}")
367
  # ----------------------------------------------------------------
368
 
369
  print(f"Checkpoint: {checkpoint_path}")
370
- print(f"Spatial upsampler: {spatial_upsampler_path}")
371
 
372
  # Initialize pipeline WITH text encoder and optional audio support
373
  # ---- Replace block (pipeline init) lines 275-281 ----
@@ -380,162 +374,141 @@ pipeline = LTX23DistilledA2VPipeline(
380
  )
381
  # ----------------------------------------------------------------
382
 
383
- def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float) -> tuple[str, str]:
384
- rp = round(float(pose_strength), 2)
385
- rg = round(float(general_strength), 2)
386
- rm = round(float(motion_strength), 2)
387
- rd = round(float(dreamlay_strength), 2)
388
- rs = round(float(mself_strength), 2)
389
- rr = round(float(dramatic_strength), 2)
390
- rf = round(float(fluid_strength), 2)
391
- rl = round(float(liquid_strength), 2)
392
- ro = round(float(demopose_strength), 2)
393
- rv = round(float(voice_strength), 2)
394
- re = round(float(realism_strength), 2)
395
- rt = round(float(transition_strength), 2)
396
- key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}|{dreamlay_lora_path}:{rd}|{mself_lora_path}:{rs}|{dramatic_lora_path}:{rr}|{fluid_lora_path}:{rf}|{liquid_lora_path}:{rl}|{demopose_lora_path}:{ro}|{voice_lora_path}:{rv}|{realism_lora_path}:{re}|{transition_lora_path}:{rt}"
397
- key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
398
- return key, key_str
399
-
400
-
401
- def prepare_lora_cache(
402
- pose_strength: float,
403
- general_strength: float,
404
- motion_strength: float,
405
- dreamlay_strength: float,
406
- mself_strength: float,
407
- dramatic_strength: float,
408
- fluid_strength: float,
409
- liquid_strength: float,
410
- demopose_strength: float,
411
- voice_strength: float,
412
- realism_strength: float,
413
- transition_strength: float,
414
- progress=gr.Progress(track_tqdm=True),
415
- ):
416
- """
417
- CPU-only step:
418
- - checks cache
419
- - loads cached fused transformer state_dict, or
420
- - builds fused transformer on CPU and saves it
421
- The resulting state_dict is stored in memory and can be applied later.
422
- """
423
- global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
424
-
425
- ledger = pipeline.model_ledger
426
- key, _ = _make_lora_key(pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength)
427
- cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
428
-
429
- progress(0.05, desc="Preparing LoRA state")
430
- if cache_path.exists():
431
- try:
432
- progress(0.20, desc="Loading cached fused state")
433
- state = load_file(str(cache_path))
434
- PENDING_LORA_KEY = key
435
- PENDING_LORA_STATE = state
436
- PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
437
- return PENDING_LORA_STATUS
438
- except Exception as e:
439
- print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
440
-
441
- entries = [
442
- (pose_lora_path, round(float(pose_strength), 2)),
443
- (general_lora_path, round(float(general_strength), 2)),
444
- (motion_lora_path, round(float(motion_strength), 2)),
445
- (dreamlay_lora_path, round(float(dreamlay_strength), 2)),
446
- (mself_lora_path, round(float(mself_strength), 2)),
447
- (dramatic_lora_path, round(float(dramatic_strength), 2)),
448
- (fluid_lora_path, round(float(fluid_strength), 2)),
449
- (liquid_lora_path, round(float(liquid_strength), 2)),
450
- (demopose_lora_path, round(float(demopose_strength), 2)),
451
- (voice_lora_path, round(float(voice_strength), 2)),
452
- (realism_lora_path, round(float(realism_strength), 2)),
453
- (transition_lora_path, round(float(transition_strength), 2)),
454
- ]
455
- loras_for_builder = [
456
- LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
457
- for path, strength in entries
458
- if path is not None and float(strength) != 0.0
459
- ]
460
 
461
- if not loras_for_builder:
462
- PENDING_LORA_KEY = None
463
- PENDING_LORA_STATE = None
464
- PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
465
- return PENDING_LORA_STATUS
466
 
467
- tmp_ledger = None
468
- new_transformer_cpu = None
469
- try:
470
- progress(0.35, desc="Building fused CPU transformer")
471
- tmp_ledger = pipeline.model_ledger.__class__(
472
- dtype=ledger.dtype,
473
- device=torch.device("cpu"),
474
- checkpoint_path=str(checkpoint_path),
475
- spatial_upsampler_path=str(spatial_upsampler_path),
476
- gemma_root_path=str(gemma_root),
477
- loras=tuple(loras_for_builder),
478
- quantization=getattr(ledger, "quantization", None),
479
- )
480
- new_transformer_cpu = tmp_ledger.transformer()
481
 
482
- progress(0.70, desc="Extracting fused state_dict")
483
- state = {
484
- k: v.detach().cpu().contiguous()
485
- for k, v in new_transformer_cpu.state_dict().items()
486
- }
487
- save_file(state, str(cache_path))
488
 
489
- PENDING_LORA_KEY = key
490
- PENDING_LORA_STATE = state
491
- PENDING_LORA_STATUS = f"Built and cached LoRA state: {cache_path.name}"
492
- return PENDING_LORA_STATUS
493
 
494
- except Exception as e:
495
- import traceback
496
- print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
497
- print(traceback.format_exc())
498
- PENDING_LORA_KEY = None
499
- PENDING_LORA_STATE = None
500
- PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
501
- return PENDING_LORA_STATUS
502
-
503
- finally:
504
- try:
505
- del new_transformer_cpu
506
- except Exception:
507
- pass
508
- try:
509
- del tmp_ledger
510
- except Exception:
511
- pass
512
- gc.collect()
513
 
 
 
 
 
514
 
515
- def apply_prepared_lora_state_to_pipeline():
516
- """
517
- Fast step: copy the already prepared CPU state into the live transformer.
518
- This is the only part that should remain near generation time.
519
- """
520
- global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE
521
 
522
- if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
523
- print("[LoRA] No prepared LoRA state available; skipping.")
524
- return False
525
 
526
- if current_lora_key == PENDING_LORA_KEY:
527
- print("[LoRA] Prepared LoRA state already active; skipping.")
528
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
- existing_transformer = _transformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  with torch.no_grad():
532
- missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
533
- if missing or unexpected:
534
- print(f"[LoRA] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")
 
 
 
535
 
536
- current_lora_key = PENDING_LORA_KEY
537
- print("[LoRA] Prepared LoRA state applied to the pipeline.")
538
- return True
 
539
 
540
  # ---- REPLACE PRELOAD BLOCK START ----
541
  # Preload all models for ZeroGPU tensor packing.
@@ -550,7 +523,6 @@ _orig_video_decoder_factory = ledger.video_decoder
550
  _orig_audio_encoder_factory = ledger.audio_encoder
551
  _orig_audio_decoder_factory = ledger.audio_decoder
552
  _orig_vocoder_factory = ledger.vocoder
553
- _orig_spatial_upsampler_factory = ledger.spatial_upsampler
554
  _orig_text_encoder_factory = ledger.text_encoder
555
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
556
 
@@ -561,7 +533,6 @@ _video_decoder = _orig_video_decoder_factory()
561
  _audio_encoder = _orig_audio_encoder_factory()
562
  _audio_decoder = _orig_audio_decoder_factory()
563
  _vocoder = _orig_vocoder_factory()
564
- _spatial_upsampler = _orig_spatial_upsampler_factory()
565
  _text_encoder = _orig_text_encoder_factory()
566
  _embeddings_processor = _orig_gemma_embeddings_factory()
567
 
@@ -573,7 +544,6 @@ ledger.video_decoder = lambda: _video_decoder
573
  ledger.audio_encoder = lambda: _audio_encoder
574
  ledger.audio_decoder = lambda: _audio_decoder
575
  ledger.vocoder = lambda: _vocoder
576
- ledger.spatial_upsampler = lambda: _spatial_upsampler
577
  ledger.text_encoder = lambda: _text_encoder
578
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
579
 
@@ -716,7 +686,11 @@ def generate_video(
716
 
717
  log_memory("before pipeline call")
718
 
719
- apply_prepared_lora_state_to_pipeline()
 
 
 
 
720
 
721
  video, audio = pipeline(
722
  prompt=prompt,
@@ -833,11 +807,6 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
833
  label="Transition strength",
834
  minimum=0.0, maximum=2.0, value=0.0, step=0.01
835
  )
836
- prepare_lora_btn = gr.Button("Prepare / Load LoRA Cache", variant="secondary")
837
- lora_status = gr.Textbox(
838
- label="LoRA Cache Status",
839
- value="No LoRA state prepared yet.",
840
- interactive=False,
841
  )
842
 
843
  with gr.Column():
@@ -907,12 +876,6 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
907
  inputs=[first_image, last_image, high_res],
908
  outputs=[width, height],
909
  )
910
-
911
- prepare_lora_btn.click(
912
- fn=prepare_lora_cache,
913
- inputs=[pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength],
914
- outputs=[lora_status],
915
- )
916
 
917
  generate_btn.click(
918
  fn=generate_video,
 
252
  LORA_CACHE_DIR.mkdir(exist_ok=True)
253
  current_lora_key: str | None = None
254
 
 
 
 
 
255
  weights_dir = Path("weights")
256
  weights_dir.mkdir(exist_ok=True)
257
  checkpoint_path = hf_hub_download(
 
260
  local_dir=str(weights_dir),
261
  local_dir_use_symlinks=False,
262
  )
 
263
 
264
  print("[Gemma] Setting up abliterated Gemma text encoder...")
265
  MERGED_WEIGHTS = "/tmp/abliterated_gemma_merged.safetensors"
 
362
  # ----------------------------------------------------------------
363
 
364
  print(f"Checkpoint: {checkpoint_path}")
 
365
 
366
  # Initialize pipeline WITH text encoder and optional audio support
367
  # ---- Replace block (pipeline init) lines 275-281 ----
 
374
  )
375
  # ----------------------------------------------------------------
376
 
377
+ # Currently applied LoRA deltas stored so they can be undone before re-applying
378
+ _applied_lora_deltas: dict[str, torch.Tensor] = {}
379
+ _applied_lora_config: list[tuple[str, float]] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
 
 
 
 
 
381
 
382
+ def _load_and_rename_lora_tensors(lora_path: str) -> dict[str, torch.Tensor]:
383
+ """Load LoRA tensors from disk and apply ComfyUI→LTX key renaming."""
384
+ tensors = {}
385
+ with safe_open(lora_path, framework="pt", device="cpu") as f:
386
+ for key in f.keys():
387
+ tensors[key] = f.get_tensor(key)
 
 
 
 
 
 
 
 
388
 
389
+ renamed = {}
390
+ for key, tensor in tensors.items():
391
+ new_key = key
392
+ for old_substr, new_substr in LTXV_LORA_COMFY_RENAMING_MAP.items():
393
+ new_key = new_key.replace(old_substr, new_substr)
394
+ renamed[new_key] = tensor
395
 
396
+ return renamed
 
 
 
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
+ def _compute_lora_deltas(lora_path: str, strength: float) -> dict[str, torch.Tensor]:
400
+ """Compute weight delta tensors for a single LoRA at given strength."""
401
+ tensors = _load_and_rename_lora_tensors(lora_path)
402
+ deltas = {}
403
 
404
+ # Collect all base keys that have a down component
405
+ base_keys = set()
406
+ for key in tensors:
407
+ for suffix in [".lora_down.weight", ".lora_A.weight"]:
408
+ if key.endswith(suffix):
409
+ base_keys.add(key[: -len(suffix)])
410
 
411
+ for base in base_keys:
412
+ down = tensors.get(base + ".lora_down.weight") or tensors.get(base + ".lora_A.weight")
413
+ up = tensors.get(base + ".lora_up.weight") or tensors.get(base + ".lora_B.weight")
414
 
415
+ if down is None or up is None:
416
+ continue
417
+
418
+ alpha_val = tensors.get(base + ".alpha")
419
+ scale = (alpha_val.item() / down.shape[0]) if alpha_val is not None else 1.0
420
+
421
+ down_f = down.float()
422
+ up_f = up.float()
423
+
424
+ if down_f.dim() == 2 and up_f.dim() == 2:
425
+ delta = up_f @ down_f
426
+ elif down_f.dim() == 4:
427
+ delta = (up_f.flatten(1) @ down_f.flatten(1)).view(
428
+ up_f.shape[0], down_f.shape[1], *up_f.shape[2:]
429
+ )
430
+ else:
431
+ print(f"[LoRA] Skipping {base}: unexpected dims down={down_f.dim()} up={up_f.dim()}")
432
+ continue
433
 
434
+ deltas[base + ".weight"] = (delta * strength * scale).to(torch.bfloat16)
435
+
436
+ return deltas
437
+
438
+
439
+ def apply_loras_to_transformer(
440
+ pose_strength, general_strength, motion_strength, dreamlay_strength,
441
+ mself_strength, dramatic_strength, fluid_strength, liquid_strength,
442
+ demopose_strength, voice_strength, realism_strength, transition_strength,
443
+ ):
444
+ global _applied_lora_deltas, _applied_lora_config
445
+
446
+ lora_configs = [
447
+ (pose_lora_path, round(float(pose_strength), 2)),
448
+ (general_lora_path, round(float(general_strength), 2)),
449
+ (motion_lora_path, round(float(motion_strength), 2)),
450
+ (dreamlay_lora_path, round(float(dreamlay_strength), 2)),
451
+ (mself_lora_path, round(float(mself_strength), 2)),
452
+ (dramatic_lora_path, round(float(dramatic_strength), 2)),
453
+ (fluid_lora_path, round(float(fluid_strength), 2)),
454
+ (liquid_lora_path, round(float(liquid_strength), 2)),
455
+ (demopose_lora_path, round(float(demopose_strength), 2)),
456
+ (voice_lora_path, round(float(voice_strength), 2)),
457
+ (realism_lora_path, round(float(realism_strength), 2)),
458
+ (transition_lora_path, round(float(transition_strength), 2)),
459
+ ]
460
+
461
+ # Skip if config hasn't changed since last application
462
+ if lora_configs == _applied_lora_config:
463
+ print("[LoRA] Config unchanged, skipping re-application.")
464
+ return
465
+
466
+ # Undo previously applied deltas
467
+ if _applied_lora_deltas:
468
+ print(f"[LoRA] Undoing {len(_applied_lora_deltas)} previously applied delta(s)...")
469
+ with torch.no_grad():
470
+ for name, param in _transformer.named_parameters():
471
+ if name in _applied_lora_deltas:
472
+ param.data -= _applied_lora_deltas[name].to(
473
+ device=param.device, dtype=param.dtype
474
+ )
475
+ _applied_lora_deltas = {}
476
+ gc.collect()
477
+
478
+ active = [(p, s) for p, s in lora_configs if p is not None and s != 0.0]
479
+ if not active:
480
+ print("[LoRA] No active LoRAs.")
481
+ _applied_lora_config = lora_configs
482
+ return
483
+
484
+ print(f"[LoRA] Computing deltas for {len(active)} active LoRA(s)...")
485
+ combined_deltas: dict[str, torch.Tensor] = {}
486
+ for lora_path, strength in active:
487
+ try:
488
+ deltas = _compute_lora_deltas(lora_path, strength)
489
+ for key, delta in deltas.items():
490
+ if key in combined_deltas:
491
+ combined_deltas[key] = combined_deltas[key] + delta
492
+ else:
493
+ combined_deltas[key] = delta
494
+ print(f"[LoRA] {Path(lora_path).name}: {len(deltas)} delta(s) at strength {strength}")
495
+ except Exception as e:
496
+ import traceback
497
+ print(f"[LoRA] Failed on {lora_path}: {e}\n{traceback.format_exc()}")
498
+
499
+ applied_count = 0
500
  with torch.no_grad():
501
+ for name, param in _transformer.named_parameters():
502
+ if name in combined_deltas:
503
+ param.data += combined_deltas[name].to(
504
+ device=param.device, dtype=param.dtype
505
+ )
506
+ applied_count += 1
507
 
508
+ _applied_lora_deltas = combined_deltas
509
+ _applied_lora_config = lora_configs
510
+ print(f"[LoRA] Applied {applied_count} weight delta(s) to live transformer.")
511
+ gc.collect()
512
 
513
  # ---- REPLACE PRELOAD BLOCK START ----
514
  # Preload all models for ZeroGPU tensor packing.
 
523
  _orig_audio_encoder_factory = ledger.audio_encoder
524
  _orig_audio_decoder_factory = ledger.audio_decoder
525
  _orig_vocoder_factory = ledger.vocoder
 
526
  _orig_text_encoder_factory = ledger.text_encoder
527
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
528
 
 
533
  _audio_encoder = _orig_audio_encoder_factory()
534
  _audio_decoder = _orig_audio_decoder_factory()
535
  _vocoder = _orig_vocoder_factory()
 
536
  _text_encoder = _orig_text_encoder_factory()
537
  _embeddings_processor = _orig_gemma_embeddings_factory()
538
 
 
544
  ledger.audio_encoder = lambda: _audio_encoder
545
  ledger.audio_decoder = lambda: _audio_decoder
546
  ledger.vocoder = lambda: _vocoder
 
547
  ledger.text_encoder = lambda: _text_encoder
548
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
549
 
 
686
 
687
  log_memory("before pipeline call")
688
 
689
+ apply_loras_to_transformer(
690
+ pose_strength, general_strength, motion_strength, dreamlay_strength,
691
+ mself_strength, dramatic_strength, fluid_strength, liquid_strength,
692
+ demopose_strength, voice_strength, realism_strength, transition_strength,
693
+ )
694
 
695
  video, audio = pipeline(
696
  prompt=prompt,
 
807
  label="Transition strength",
808
  minimum=0.0, maximum=2.0, value=0.0, step=0.01
809
  )
 
 
 
 
 
810
  )
811
 
812
  with gr.Column():
 
876
  inputs=[first_image, last_image, high_res],
877
  outputs=[width, height],
878
  )
 
 
 
 
 
 
879
 
880
  generate_btn.click(
881
  fn=generate_video,