dagloop5 committed on
Commit
8fc30ea
·
verified ·
1 Parent(s): ea4b995

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -24
app.py CHANGED
@@ -279,7 +279,14 @@ PENDING_LORA_KEY: str | None = None
279
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
280
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
281
 
282
- checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
 
 
 
 
 
 
 
283
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
284
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
285
 
@@ -375,29 +382,16 @@ def prepare_lora_cache(
375
  new_transformer_cpu = None
376
  try:
377
  progress(0.35, desc="Building fused CPU transformer")
378
- tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
379
-
380
- orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
381
- orig_tmp_device = getattr(tmp_ledger, "device", None)
382
- try:
383
- tmp_ledger._target_device = lambda: torch.device("cpu")
384
- tmp_ledger.device = torch.device("cpu")
385
- new_transformer_cpu = tmp_ledger.transformer()
386
- finally:
387
- if orig_tmp_target is not None:
388
- tmp_ledger._target_device = orig_tmp_target
389
- else:
390
- try:
391
- delattr(tmp_ledger, "_target_device")
392
- except Exception:
393
- pass
394
- if orig_tmp_device is not None:
395
- tmp_ledger.device = orig_tmp_device
396
- else:
397
- try:
398
- delattr(tmp_ledger, "device")
399
- except Exception:
400
- pass
401
 
402
  progress(0.70, desc="Extracting fused state_dict")
403
  state = new_transformer_cpu.state_dict()
 
279
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
280
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
281
 
282
+ weights_dir = Path("weights")
283
+ weights_dir.mkdir(exist_ok=True)
284
+ checkpoint_path = hf_hub_download(
285
+ repo_id=LTX_MODEL_REPO,
286
+ filename="ltx-2.3-22b-distilled.safetensors",
287
+ local_dir=str(weights_dir),
288
+ local_dir_use_symlinks=False,
289
+ )
290
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
291
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
292
 
 
382
  new_transformer_cpu = None
383
  try:
384
  progress(0.35, desc="Building fused CPU transformer")
385
+ tmp_ledger = pipeline.model_ledger.__class__(
386
+ dtype=ledger.dtype,
387
+ device=torch.device("cpu"),
388
+ checkpoint_path=str(checkpoint_path),
389
+ spatial_upsampler_path=str(spatial_upsampler_path),
390
+ gemma_root_path=str(gemma_root),
391
+ loras=tuple(loras_for_builder),
392
+ quantization=getattr(ledger, "quantization", None),
393
+ )
394
+ new_transformer_cpu = tmp_ledger.transformer()
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
  progress(0.70, desc="Extracting fused state_dict")
397
  state = new_transformer_cpu.state_dict()