dagloop5 committed on
Commit
eccc5a3
·
verified ·
1 Parent(s): b61b1e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -24
app.py CHANGED
@@ -276,7 +276,14 @@ PENDING_LORA_KEY: str | None = None
276
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
277
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
278
 
279
- checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-distilled.safetensors")
 
 
 
 
 
 
 
280
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
281
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
282
 
@@ -372,29 +379,16 @@ def prepare_lora_cache(
372
  new_transformer_cpu = None
373
  try:
374
  progress(0.35, desc="Building fused CPU transformer")
375
- tmp_ledger = ledger.with_loras(tuple(loras_for_builder))
376
-
377
- orig_tmp_target = getattr(tmp_ledger, "_target_device", None)
378
- orig_tmp_device = getattr(tmp_ledger, "device", None)
379
- try:
380
- tmp_ledger._target_device = lambda: torch.device("cpu")
381
- tmp_ledger.device = torch.device("cpu")
382
- new_transformer_cpu = tmp_ledger.transformer()
383
- finally:
384
- if orig_tmp_target is not None:
385
- tmp_ledger._target_device = orig_tmp_target
386
- else:
387
- try:
388
- delattr(tmp_ledger, "_target_device")
389
- except Exception:
390
- pass
391
- if orig_tmp_device is not None:
392
- tmp_ledger.device = orig_tmp_device
393
- else:
394
- try:
395
- delattr(tmp_ledger, "device")
396
- except Exception:
397
- pass
398
 
399
  progress(0.70, desc="Extracting fused state_dict")
400
  state = new_transformer_cpu.state_dict()
 
276
  PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
277
  PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
278
 
279
+ weights_dir = Path("weights")
280
+ weights_dir.mkdir(exist_ok=True)
281
+ checkpoint_path = hf_hub_download(
282
+ repo_id=LTX_MODEL_REPO,
283
+ filename="ltx-2.3-22b-distilled.safetensors",
284
+ local_dir=str(weights_dir),
285
+ local_dir_use_symlinks=False,
286
+ )
287
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.1.safetensors")
288
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
289
 
 
379
  new_transformer_cpu = None
380
  try:
381
  progress(0.35, desc="Building fused CPU transformer")
382
+ tmp_ledger = pipeline.model_ledger.__class__(
383
+ dtype=ledger.dtype,
384
+ device=torch.device("cpu"),
385
+ checkpoint_path=str(checkpoint_path),
386
+ spatial_upsampler_path=str(spatial_upsampler_path),
387
+ gemma_root_path=str(gemma_root),
388
+ loras=tuple(loras_for_builder),
389
+ quantization=getattr(ledger, "quantization", None),
390
+ )
391
+ new_transformer_cpu = tmp_ledger.transformer()
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
  progress(0.70, desc="Extracting fused state_dict")
394
  state = new_transformer_cpu.state_dict()