Update app.py
Browse files
app.py
CHANGED
|
@@ -473,7 +473,7 @@ def prepare_lora_cache(
|
|
| 473 |
spatial_upsampler_path=str(spatial_upsampler_path),
|
| 474 |
gemma_root_path=str(gemma_root),
|
| 475 |
loras=tuple(loras_for_builder),
|
| 476 |
-
quantization=
|
| 477 |
)
|
| 478 |
new_transformer_cpu = tmp_ledger.transformer()
|
| 479 |
|
|
@@ -525,7 +525,7 @@ def apply_prepared_lora_state_to_pipeline():
|
|
| 525 |
print("[LoRA] Prepared LoRA state already active; skipping.")
|
| 526 |
return True
|
| 527 |
|
| 528 |
-
existing_transformer =
|
| 529 |
with torch.no_grad():
|
| 530 |
missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
|
| 531 |
if missing or unexpected:
|
|
@@ -552,8 +552,7 @@ _orig_spatial_upsampler_factory = ledger.spatial_upsampler
|
|
| 552 |
_orig_text_encoder_factory = ledger.text_encoder
|
| 553 |
_orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
|
| 554 |
|
| 555 |
-
#
|
| 556 |
-
_transformer = _orig_transformer_factory()
|
| 557 |
_video_encoder = _orig_video_encoder_factory()
|
| 558 |
_video_decoder = _orig_video_decoder_factory()
|
| 559 |
_audio_encoder = _orig_audio_encoder_factory()
|
|
@@ -563,9 +562,18 @@ _spatial_upsampler = _orig_spatial_upsampler_factory()
|
|
| 563 |
_text_encoder = _orig_text_encoder_factory()
|
| 564 |
_embeddings_processor = _orig_gemma_embeddings_factory()
|
| 565 |
|
| 566 |
-
#
|
| 567 |
-
#
|
| 568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
ledger.video_encoder = lambda: _video_encoder
|
| 570 |
ledger.video_decoder = lambda: _video_decoder
|
| 571 |
ledger.audio_encoder = lambda: _audio_encoder
|
|
@@ -575,7 +583,7 @@ ledger.spatial_upsampler = lambda: _spatial_upsampler
|
|
| 575 |
ledger.text_encoder = lambda: _text_encoder
|
| 576 |
ledger.gemma_embeddings_processor = lambda: _embeddings_processor
|
| 577 |
|
| 578 |
-
print("All models preloaded
|
| 579 |
# ---- REPLACE PRELOAD BLOCK END ----
|
| 580 |
|
| 581 |
print("=" * 80)
|
|
|
|
| 473 |
spatial_upsampler_path=str(spatial_upsampler_path),
|
| 474 |
gemma_root_path=str(gemma_root),
|
| 475 |
loras=tuple(loras_for_builder),
|
| 476 |
+
quantization=None,
|
| 477 |
)
|
| 478 |
new_transformer_cpu = tmp_ledger.transformer()
|
| 479 |
|
|
|
|
| 525 |
print("[LoRA] Prepared LoRA state already active; skipping.")
|
| 526 |
return True
|
| 527 |
|
| 528 |
+
existing_transformer = get_transformer()
|
| 529 |
with torch.no_grad():
|
| 530 |
missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
|
| 531 |
if missing or unexpected:
|
|
|
|
| 552 |
_orig_text_encoder_factory = ledger.text_encoder
|
| 553 |
_orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor
|
| 554 |
|
| 555 |
+
# Keep everything else cached as before.
|
|
|
|
| 556 |
_video_encoder = _orig_video_encoder_factory()
|
| 557 |
_video_decoder = _orig_video_decoder_factory()
|
| 558 |
_audio_encoder = _orig_audio_encoder_factory()
|
|
|
|
| 562 |
_text_encoder = _orig_text_encoder_factory()
|
| 563 |
_embeddings_processor = _orig_gemma_embeddings_factory()
|
| 564 |
|
| 565 |
+
# Do NOT build the transformer here.
|
| 566 |
+
# Build it lazily only when generation or LoRA application actually needs it.
|
| 567 |
+
_transformer = None
|
| 568 |
+
|
| 569 |
+
def get_transformer():
    """Return the shared transformer, constructing it on first use.

    The module-level ``_transformer`` slot starts out as ``None``. The first
    call fills it with the result of ``_orig_transformer_factory()``; every
    later call hands back that same cached object without rebuilding it.

    NOTE(review): there is no locking around the first build -- confirm that
    callers cannot race on the initial call.
    """
    global _transformer
    # Fast path: the transformer was already built by an earlier call.
    if _transformer is not None:
        return _transformer
    # Slow path: build once and remember the instance for subsequent calls.
    _transformer = _orig_transformer_factory()
    return _transformer
|
| 574 |
+
|
| 575 |
+
# Replace ledger methods with lightweight getters.
|
| 576 |
+
ledger.transformer = get_transformer
|
| 577 |
ledger.video_encoder = lambda: _video_encoder
|
| 578 |
ledger.video_decoder = lambda: _video_decoder
|
| 579 |
ledger.audio_encoder = lambda: _audio_encoder
|
|
|
|
| 583 |
ledger.text_encoder = lambda: _text_encoder
|
| 584 |
ledger.gemma_embeddings_processor = lambda: _embeddings_processor
|
| 585 |
|
| 586 |
+
print("All non-transformer models preloaded; transformer will be built lazily.")
|
| 587 |
# ---- REPLACE PRELOAD BLOCK END ----
|
| 588 |
|
| 589 |
print("=" * 80)
|