dagloop5 committed on
Commit
d2437a0
·
verified ·
1 Parent(s): d01a956

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -17
app.py CHANGED
@@ -370,17 +370,47 @@ pipeline = LTX23DistilledA2VPipeline(
370
  distilled_checkpoint_path=checkpoint_path,
371
  spatial_upsampler_path=spatial_upsampler_path,
372
  gemma_root=gemma_root,
373
- loras=[
374
- LoraPathStrengthAndSDOps(
375
- distilled_lora_path,
376
- 1.0,
377
- LTXV_LORA_COMFY_RENAMING_MAP,
378
- )
379
- ],
380
  quantization=QuantizationPolicy.fp8_cast(), # keep FP8 quantization unchanged
381
  )
382
  # ----------------------------------------------------------------
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float) -> tuple[str, str]:
385
  rp = round(float(pose_strength), 2)
386
  rg = round(float(general_strength), 2)
@@ -552,18 +582,15 @@ _orig_spatial_upsampler_factory = ledger.spatial_upsampler
552
  _orig_text_encoder_factory = ledger.text_encoder
553
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
554
 
555
- # Keep everything else cached as before.
556
- _video_encoder = _orig_video_encoder_factory()
557
- _video_decoder = _orig_video_decoder_factory()
558
- _audio_encoder = _orig_audio_encoder_factory()
559
- _audio_decoder = _orig_audio_decoder_factory()
560
- _vocoder = _orig_vocoder_factory()
561
- _spatial_upsampler = _orig_spatial_upsampler_factory()
562
- _text_encoder = _orig_text_encoder_factory()
563
- _embeddings_processor = _orig_gemma_embeddings_factory()
564
-
565
  # Call the original factories once to create the cached instances we will serve by default.
566
  _transformer = _orig_transformer_factory()
 
 
 
 
 
 
 
567
  _video_encoder = _orig_video_encoder_factory()
568
  _video_decoder = _orig_video_decoder_factory()
569
  _audio_encoder = _orig_audio_encoder_factory()
 
370
  distilled_checkpoint_path=checkpoint_path,
371
  spatial_upsampler_path=spatial_upsampler_path,
372
  gemma_root=gemma_root,
373
+ loras=[],
 
 
 
 
 
 
374
  quantization=QuantizationPolicy.fp8_cast(), # keep FP8 quantization unchanged
375
  )
376
  # ----------------------------------------------------------------
377
 
378
+ DISTILLED_DEFAULT_STATE: dict[str, torch.Tensor] | None = None
379
+
380
+ def prepare_distilled_default_state():
381
+ global DISTILLED_DEFAULT_STATE
382
+
383
+ if DISTILLED_DEFAULT_STATE is not None:
384
+ return
385
+
386
+ print("Preparing distilled default LoRA state on CPU...")
387
+ tmp_ledger = pipeline.model_ledger.__class__(
388
+ dtype=pipeline.model_ledger.dtype,
389
+ device=torch.device("cpu"),
390
+ checkpoint_path=str(checkpoint_path),
391
+ spatial_upsampler_path=str(spatial_upsampler_path),
392
+ gemma_root_path=str(gemma_root),
393
+ loras=(
394
+ LoraPathStrengthAndSDOps(
395
+ distilled_lora_path,
396
+ 1.0,
397
+ LTXV_LORA_COMFY_RENAMING_MAP,
398
+ ),
399
+ ),
400
+ quantization=None,
401
+ )
402
+
403
+ distilled_transformer = tmp_ledger.transformer()
404
+ DISTILLED_DEFAULT_STATE = {
405
+ k: v.detach().cpu().contiguous()
406
+ for k, v in distilled_transformer.state_dict().items()
407
+ }
408
+
409
+ del distilled_transformer
410
+ del tmp_ledger
411
+ gc.collect()
412
+ print("Distilled default LoRA state prepared.")
413
+
414
  def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float) -> tuple[str, str]:
415
  rp = round(float(pose_strength), 2)
416
  rg = round(float(general_strength), 2)
 
582
  _orig_text_encoder_factory = ledger.text_encoder
583
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
584
 
 
 
 
 
 
 
 
 
 
 
585
  # Call the original factories once to create the cached instances we will serve by default.
586
  _transformer = _orig_transformer_factory()
587
+ _transformer = _orig_transformer_factory()
588
+ if DISTILLED_DEFAULT_STATE is not None:
589
+ with torch.no_grad():
590
+ missing, unexpected = _transformer.load_state_dict(DISTILLED_DEFAULT_STATE, strict=False)
591
+ if missing or unexpected:
592
+ print(f"[Distilled default] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")
593
+ print("[Distilled default] applied to transformer.")
594
  _video_encoder = _orig_video_encoder_factory()
595
  _video_decoder = _orig_video_decoder_factory()
596
  _audio_encoder = _orig_audio_encoder_factory()