dagloop5 committed on
Commit
817abe5
·
verified ·
1 Parent(s): 0b6d12e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -473,7 +473,7 @@ def prepare_lora_cache(
473
  spatial_upsampler_path=str(spatial_upsampler_path),
474
  gemma_root_path=str(gemma_root),
475
  loras=tuple(loras_for_builder),
476
- quantization=getattr(ledger, "quantization", None),
477
  )
478
  new_transformer_cpu = tmp_ledger.transformer()
479
 
@@ -525,7 +525,7 @@ def apply_prepared_lora_state_to_pipeline():
525
  print("[LoRA] Prepared LoRA state already active; skipping.")
526
  return True
527
 
528
- existing_transformer = _transformer
529
  with torch.no_grad():
530
  missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
531
  if missing or unexpected:
@@ -552,8 +552,7 @@ _orig_spatial_upsampler_factory = ledger.spatial_upsampler
552
  _orig_text_encoder_factory = ledger.text_encoder
553
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
554
 
555
- # Call the original factories once to create the cached instances we will serve by default.
556
- _transformer = _orig_transformer_factory()
557
  _video_encoder = _orig_video_encoder_factory()
558
  _video_decoder = _orig_video_decoder_factory()
559
  _audio_encoder = _orig_audio_encoder_factory()
@@ -563,9 +562,18 @@ _spatial_upsampler = _orig_spatial_upsampler_factory()
563
  _text_encoder = _orig_text_encoder_factory()
564
  _embeddings_processor = _orig_gemma_embeddings_factory()
565
 
566
- # Replace ledger methods with lightweight lambdas that return the cached instances.
567
- # We keep the original factories above so we can call them later to rebuild components.
568
- ledger.transformer = lambda: _transformer
 
 
 
 
 
 
 
 
 
569
  ledger.video_encoder = lambda: _video_encoder
570
  ledger.video_decoder = lambda: _video_decoder
571
  ledger.audio_encoder = lambda: _audio_encoder
@@ -575,7 +583,7 @@ ledger.spatial_upsampler = lambda: _spatial_upsampler
575
  ledger.text_encoder = lambda: _text_encoder
576
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
577
 
578
- print("All models preloaded (including Gemma text encoder and audio encoder)!")
579
  # ---- REPLACE PRELOAD BLOCK END ----
580
 
581
  print("=" * 80)
 
473
  spatial_upsampler_path=str(spatial_upsampler_path),
474
  gemma_root_path=str(gemma_root),
475
  loras=tuple(loras_for_builder),
476
+ quantization=None,
477
  )
478
  new_transformer_cpu = tmp_ledger.transformer()
479
 
 
525
  print("[LoRA] Prepared LoRA state already active; skipping.")
526
  return True
527
 
528
+ existing_transformer = get_transformer()
529
  with torch.no_grad():
530
  missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
531
  if missing or unexpected:
 
552
  _orig_text_encoder_factory = ledger.text_encoder
553
  _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
554
 
555
+ # Keep everything else cached as before.
 
556
  _video_encoder = _orig_video_encoder_factory()
557
  _video_decoder = _orig_video_decoder_factory()
558
  _audio_encoder = _orig_audio_encoder_factory()
 
562
  _text_encoder = _orig_text_encoder_factory()
563
  _embeddings_processor = _orig_gemma_embeddings_factory()
564
 
565
+ # Do NOT build the transformer here.
566
+ # Build it lazily only when generation or LoRA application actually needs it.
567
+ _transformer = None
568
+
569
+ def get_transformer():
570
+ global _transformer
571
+ if _transformer is None:
572
+ _transformer = _orig_transformer_factory()
573
+ return _transformer
574
+
575
+ # Replace ledger methods with lightweight getters.
576
+ ledger.transformer = get_transformer
577
  ledger.video_encoder = lambda: _video_encoder
578
  ledger.video_decoder = lambda: _video_decoder
579
  ledger.audio_encoder = lambda: _audio_encoder
 
583
  ledger.text_encoder = lambda: _text_encoder
584
  ledger.gemma_embeddings_processor = lambda: _embeddings_processor
585
 
586
+ print("All non-transformer models preloaded; transformer will be built lazily.")
587
  # ---- REPLACE PRELOAD BLOCK END ----
588
 
589
  print("=" * 80)