Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -456,24 +456,44 @@ def load_lora_into_cache(lora_path: str) -> StateDict:
|
|
| 456 |
print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
|
| 457 |
return state_dict
|
| 458 |
|
| 459 |
-
def _rename_lora_keys_for_base_model(lora_sd: StateDict) -> StateDict:
|
| 460 |
"""
|
| 461 |
Rename LoRA state dict keys to match the base model's key format.
|
| 462 |
|
| 463 |
-
LoRA files
|
| 464 |
-
|
| 465 |
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
"""
|
| 471 |
renamed_sd = {}
|
| 472 |
for key, tensor in lora_sd.sd.items():
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
renamed_sd[new_key] = tensor
|
| 478 |
|
| 479 |
return StateDict(
|
|
@@ -507,14 +527,29 @@ def build_fused_state_dict(
|
|
| 507 |
if progress_callback:
|
| 508 |
progress_callback(0.1, "Loading LoRA state dicts into memory")
|
| 509 |
|
| 510 |
-
# Step 1: Load all LoRA state dicts
|
| 511 |
-
# and rename keys to match base model's key format
|
| 512 |
lora_sd_with_strengths = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
for lora_path, strength in lora_configs:
|
| 514 |
sd = load_lora_into_cache(lora_path)
|
| 515 |
-
|
| 516 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
lora_sd_with_strengths.append(LoraStateDictWithStrength(sd_renamed, float(strength)))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
|
| 519 |
if progress_callback:
|
| 520 |
progress_callback(0.3, "Extracting base transformer state dict")
|
|
@@ -577,6 +612,14 @@ def on_prepare_loras_click(
|
|
| 577 |
Only runs on button click, NOT on slider change.
|
| 578 |
"""
|
| 579 |
global current_lora_key, FUSED_CACHE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
|
| 581 |
# Compute the cache key for this combination of strengths
|
| 582 |
key, _ = _make_lora_key(
|
|
|
|
| 456 |
print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
|
| 457 |
return state_dict
|
| 458 |
|
| 459 |
+
def _rename_lora_keys_for_base_model(lora_sd: StateDict, base_keys: set[str]) -> StateDict:
|
| 460 |
"""
|
| 461 |
Rename LoRA state dict keys to match the base model's key format.
|
| 462 |
|
| 463 |
+
LoRA files from different sources use various key structures.
|
| 464 |
+
We need to normalize them to match the base model.
|
| 465 |
|
| 466 |
+
Common patterns:
|
| 467 |
+
- "diffusion_model.transformer.xxx" -> "xxx" (ComfyUI export)
|
| 468 |
+
- "diffusion_model.xxx" -> "xxx" (strip prefix)
|
| 469 |
+
- "transformer.xxx" -> "xxx" (strip module name)
|
| 470 |
"""
|
| 471 |
renamed_sd = {}
|
| 472 |
for key, tensor in lora_sd.sd.items():
|
| 473 |
+
new_key = key
|
| 474 |
+
|
| 475 |
+
# Strip "diffusion_model." prefix
|
| 476 |
+
if new_key.startswith("diffusion_model."):
|
| 477 |
+
new_key = new_key[len("diffusion_model."):]
|
| 478 |
+
|
| 479 |
+
# If key still doesn't match any base keys, try stripping "transformer." prefix
|
| 480 |
+
if new_key not in base_keys and new_key.startswith("transformer."):
|
| 481 |
+
new_key = new_key[len("transformer."):]
|
| 482 |
+
|
| 483 |
+
# If it's a LoRA key, try to match the base weight key
|
| 484 |
+
if ".lora_A.weight" in new_key or ".lora_B.weight" in new_key:
|
| 485 |
+
base_path = new_key.replace(".lora_A.weight", ".weight").replace(".lora_B.weight", ".weight")
|
| 486 |
+
if base_path not in base_keys:
|
| 487 |
+
parts = base_path.split(".")
|
| 488 |
+
if parts[0] == "transformer":
|
| 489 |
+
alt_path = ".".join(parts[1:])
|
| 490 |
+
if alt_path in base_keys:
|
| 491 |
+
new_key = new_key.replace(base_path, alt_path)
|
| 492 |
+
elif parts[0] == "layer" and len(parts) > 1:
|
| 493 |
+
alt_path = ".".join(parts[1:])
|
| 494 |
+
if alt_path in base_keys:
|
| 495 |
+
new_key = new_key.replace(base_path, alt_path)
|
| 496 |
+
|
| 497 |
renamed_sd[new_key] = tensor
|
| 498 |
|
| 499 |
return StateDict(
|
|
|
|
| 527 |
if progress_callback:
|
| 528 |
progress_callback(0.1, "Loading LoRA state dicts into memory")
|
| 529 |
|
| 530 |
+
# Step 1: Load all LoRA state dicts and rename keys to match base model
|
|
|
|
| 531 |
lora_sd_with_strengths = []
|
| 532 |
+
|
| 533 |
+
# Get base key set for matching
|
| 534 |
+
base_dict = base_transformer.state_dict()
|
| 535 |
+
base_key_set = set(base_dict.keys())
|
| 536 |
+
print(f"[LoRA DEBUG] Total base model keys: {len(base_key_set)}")
|
| 537 |
+
|
| 538 |
for lora_path, strength in lora_configs:
|
| 539 |
sd = load_lora_into_cache(lora_path)
|
| 540 |
+
sd_renamed = _rename_lora_keys_for_base_model(sd, base_key_set)
|
| 541 |
+
|
| 542 |
+
# Show before/after for first few keys
|
| 543 |
+
original_keys = list(sd.sd.keys())[:3]
|
| 544 |
+
renamed_keys = list(sd_renamed.sd.keys())[:3]
|
| 545 |
+
print(f"[LoRA DEBUG] Before rename: {original_keys}")
|
| 546 |
+
print(f"[LoRA DEBUG] After rename: {renamed_keys}")
|
| 547 |
+
|
| 548 |
lora_sd_with_strengths.append(LoraStateDictWithStrength(sd_renamed, float(strength)))
|
| 549 |
+
print(f"[LoRA] Loaded and renamed {len(sd_renamed.sd)} keys from {os.path.basename(lora_path)}")
|
| 550 |
+
|
| 551 |
+
# Debug: Check LoRA key matching
|
| 552 |
+
print(f"[LoRA DEBUG] Sample base keys: {list(base_key_set)[:10]}")
|
| 553 |
|
| 554 |
if progress_callback:
|
| 555 |
progress_callback(0.3, "Extracting base transformer state dict")
|
|
|
|
| 612 |
Only runs on button click, NOT on slider change.
|
| 613 |
"""
|
| 614 |
global current_lora_key, FUSED_CACHE
|
| 615 |
+
|
| 616 |
+
# Debug: Verify transformer consistency
|
| 617 |
+
ledger_transformer = ledger.transformer()
|
| 618 |
+
pipeline_transformer = pipeline.model_ledger.transformer()
|
| 619 |
+
print(f"[LoRA DEBUG] ledger.transformer() id: {id(ledger_transformer)}")
|
| 620 |
+
print(f"[LoRA DEBUG] pipeline.model_ledger.transformer() id: {id(pipeline_transformer)}")
|
| 621 |
+
print(f"[LoRA DEBUG] Same object? {ledger_transformer is pipeline_transformer}")
|
| 622 |
+
print(f"[LoRA DEBUG] _transformer id: {id(_transformer)}")
|
| 623 |
|
| 624 |
# Compute the cache key for this combination of strengths
|
| 625 |
key, _ = _make_lora_key(
|