dagloop5 committed on
Commit
dd5f993
·
verified ·
1 Parent(s): 5490bd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -15
app.py CHANGED
@@ -456,24 +456,44 @@ def load_lora_into_cache(lora_path: str) -> StateDict:
456
  print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
457
  return state_dict
458
 
459
- def _rename_lora_keys_for_base_model(lora_sd: StateDict) -> StateDict:
460
  """
461
  Rename LoRA state dict keys to match the base model's key format.
462
 
463
- LoRA files typically use ComfyUI-style keys like:
464
- "diffusion_model.layer1.blocks.0.lora_A.weight"
465
 
466
- But the base model uses keys like:
467
- "layer1.blocks.0.lora_A.weight"
468
-
469
- This function strips the "diffusion_model." prefix to match them.
470
  """
471
  renamed_sd = {}
472
  for key, tensor in lora_sd.sd.items():
473
- if key.startswith("diffusion_model."):
474
- new_key = key[len("diffusion_model."):]
475
- else:
476
- new_key = key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  renamed_sd[new_key] = tensor
478
 
479
  return StateDict(
@@ -507,14 +527,29 @@ def build_fused_state_dict(
507
  if progress_callback:
508
  progress_callback(0.1, "Loading LoRA state dicts into memory")
509
 
510
- # Step 1: Load all LoRA state dicts (uses cache after first load)
511
- # and rename keys to match base model's key format
512
  lora_sd_with_strengths = []
 
 
 
 
 
 
513
  for lora_path, strength in lora_configs:
514
  sd = load_lora_into_cache(lora_path)
515
- # Apply key renaming to strip "diffusion_model." prefix
516
- sd_renamed = _rename_lora_keys_for_base_model(sd)
 
 
 
 
 
 
517
  lora_sd_with_strengths.append(LoraStateDictWithStrength(sd_renamed, float(strength)))
 
 
 
 
518
 
519
  if progress_callback:
520
  progress_callback(0.3, "Extracting base transformer state dict")
@@ -577,6 +612,14 @@ def on_prepare_loras_click(
577
  Only runs on button click, NOT on slider change.
578
  """
579
  global current_lora_key, FUSED_CACHE
 
 
 
 
 
 
 
 
580
 
581
  # Compute the cache key for this combination of strengths
582
  key, _ = _make_lora_key(
 
456
  print(f"[LoRA] Cached {len(tensors)} tensors from {os.path.basename(lora_path)}")
457
  return state_dict
458
 
459
+ def _rename_lora_keys_for_base_model(lora_sd: StateDict, base_keys: set[str]) -> StateDict:
460
  """
461
  Rename LoRA state dict keys to match the base model's key format.
462
 
463
+ LoRA files from different sources use various key structures.
464
+ We need to normalize them to match the base model.
465
 
466
+ Common patterns:
467
+ - "diffusion_model.transformer.xxx" -> "xxx" (ComfyUI export)
468
+ - "diffusion_model.xxx" -> "xxx" (strip prefix)
469
+ - "transformer.xxx" -> "xxx" (strip module name)
470
  """
471
  renamed_sd = {}
472
  for key, tensor in lora_sd.sd.items():
473
+ new_key = key
474
+
475
+ # Strip "diffusion_model." prefix
476
+ if new_key.startswith("diffusion_model."):
477
+ new_key = new_key[len("diffusion_model."):]
478
+
479
+ # If key still doesn't match any base keys, try stripping "transformer." prefix
480
+ if new_key not in base_keys and new_key.startswith("transformer."):
481
+ new_key = new_key[len("transformer."):]
482
+
483
+ # If it's a LoRA key, try to match the base weight key
484
+ if ".lora_A.weight" in new_key or ".lora_B.weight" in new_key:
485
+ base_path = new_key.replace(".lora_A.weight", ".weight").replace(".lora_B.weight", ".weight")
486
+ if base_path not in base_keys:
487
+ parts = base_path.split(".")
488
+ if parts[0] == "transformer":
489
+ alt_path = ".".join(parts[1:])
490
+ if alt_path in base_keys:
491
+ new_key = new_key.replace(base_path, alt_path)
492
+ elif parts[0] == "layer" and len(parts) > 1:
493
+ alt_path = ".".join(parts[1:])
494
+ if alt_path in base_keys:
495
+ new_key = new_key.replace(base_path, alt_path)
496
+
497
  renamed_sd[new_key] = tensor
498
 
499
  return StateDict(
 
527
  if progress_callback:
528
  progress_callback(0.1, "Loading LoRA state dicts into memory")
529
 
530
+ # Step 1: Load all LoRA state dicts and rename keys to match base model
 
531
  lora_sd_with_strengths = []
532
+
533
+ # Get base key set for matching
534
+ base_dict = base_transformer.state_dict()
535
+ base_key_set = set(base_dict.keys())
536
+ print(f"[LoRA DEBUG] Total base model keys: {len(base_key_set)}")
537
+
538
  for lora_path, strength in lora_configs:
539
  sd = load_lora_into_cache(lora_path)
540
+ sd_renamed = _rename_lora_keys_for_base_model(sd, base_key_set)
541
+
542
+ # Show before/after for first few keys
543
+ original_keys = list(sd.sd.keys())[:3]
544
+ renamed_keys = list(sd_renamed.sd.keys())[:3]
545
+ print(f"[LoRA DEBUG] Before rename: {original_keys}")
546
+ print(f"[LoRA DEBUG] After rename: {renamed_keys}")
547
+
548
  lora_sd_with_strengths.append(LoraStateDictWithStrength(sd_renamed, float(strength)))
549
+ print(f"[LoRA] Loaded and renamed {len(sd_renamed.sd)} keys from {os.path.basename(lora_path)}")
550
+
551
+ # Debug: Check LoRA key matching
552
+ print(f"[LoRA DEBUG] Sample base keys: {list(base_key_set)[:10]}")
553
 
554
  if progress_callback:
555
  progress_callback(0.3, "Extracting base transformer state dict")
 
612
  Only runs on button click, NOT on slider change.
613
  """
614
  global current_lora_key, FUSED_CACHE
615
+
616
+ # Debug: Verify transformer consistency
617
+ ledger_transformer = ledger.transformer()
618
+ pipeline_transformer = pipeline.model_ledger.transformer()
619
+ print(f"[LoRA DEBUG] ledger.transformer() id: {id(ledger_transformer)}")
620
+ print(f"[LoRA DEBUG] pipeline.model_ledger.transformer() id: {id(pipeline_transformer)}")
621
+ print(f"[LoRA DEBUG] Same object? {ledger_transformer is pipeline_transformer}")
622
+ print(f"[LoRA DEBUG] _transformer id: {id(_transformer)}")
623
 
624
  # Compute the cache key for this combination of strengths
625
  key, _ = _make_lora_key(