# convert_lora_i2v_to_fc.py
import torch
import safetensors.torch
import safetensors # Need this for safe_open
import argparse
import os
import re # Regular expressions might be useful for more complex key parsing if needed

# !!! IMPORTANT: Updated based on the output of analyze_wan_models.py !!!
# The base layer name identified with shape mismatch.
# Check your LoRA file's keys if they use a different prefix (e.g., 'transformer.')
# Assuming the base name identified in LoRA keys matches this.
BASE_LAYERS_TO_SKIP_LORA = {
    "patch_embedding",  # The layer name from the analysis output
    # Add other layers here ONLY if the analysis revealed more mismatches
}
# !!! END IMPORTANT SECTION !!!


def get_base_layer_name(lora_key: str, prefixes=["lora_transformer_", "lora_unet_"]):
"""
Attempts to extract the base model layer name from a LoRA key.
Handles common prefixes and suffixes. Adjust prefixes if needed.
Example: "lora_transformer_patch_embedding_down.weight" -> "patch_embedding"
"lora_transformer_blocks_0_attn_qkv.alpha" -> "blocks.0.attn.qkv"
Args:
lora_key (str): The key from the LoRA state dictionary.
prefixes (list[str]): A list of potential prefixes used in LoRA keys.
Returns:
str: The inferred base model layer name.
"""
cleaned_key = lora_key
# Remove known prefixes
for prefix in prefixes:
if cleaned_key.startswith(prefix):
cleaned_key = cleaned_key[len(prefix):]
break # Assume only one prefix matches
# Remove known suffixes
# Order matters slightly if one suffix is part of another; list longer ones first if needed
known_suffixes = [
".lora_up.weight",
".lora_down.weight",
"_lora_up.weight", # Include underscore variants just in case
"_lora_down.weight",
".alpha"
]
for suffix in known_suffixes:
if cleaned_key.endswith(suffix):
cleaned_key = cleaned_key[:-len(suffix)]
break
# Replace underscores used by some training scripts with periods for consistency
# if the original model uses periods (like typical PyTorch modules).
# Adjust this logic if the base model itself uses underscores extensively.
cleaned_key = cleaned_key.replace("_", ".")
# Specific fix for the target layer if prefix/suffix removal was incomplete or ambiguous
# This is somewhat heuristic and might need adjustment based on exact LoRA key naming.
if cleaned_key.startswith("patch.embedding"): # Handle case where prefix removal was incomplete
# Map potential variants back to the canonical name found in analysis
cleaned_key = "patch_embedding"
elif cleaned_key == "patch.embedding.weight": # If suffix removal left .weight attached somehow
cleaned_key = "patch_embedding"
# Add elif clauses here if other specific key mappings are needed
return cleaned_key
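# A minimal sketch (not part of the conversion flow) of how the key-name mapping above
# could be previewed against your own LoRA file before converting. The path
# "my_lora.safetensors" is a placeholder, not a file that ships with this script.
#
#     with safetensors.safe_open("my_lora.safetensors", framework="pt", device="cpu") as f:
#         for key in list(f.keys())[:10]:
#             print(key, "->", get_base_layer_name(key))

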
def convert_lora(source_lora_path: str, target_lora_path: str):
"""
Converts an i2v_14B LoRA to be compatible with i2v_14B_FC by
removing LoRA weights associated with layers that have incompatible shapes.
Args:
source_lora_path (str): Path to the input LoRA file (.safetensors).
target_lora_path (str): Path to save the converted LoRA file (.safetensors).
"""
print(f"Loading source LoRA from: {source_lora_path}")
if not os.path.exists(source_lora_path):
print(f"Error: Source file not found: {source_lora_path}")
return
try:
# Load tensors and metadata using safe_open for better handling
source_lora_state_dict = {}
metadata = {}
with safetensors.safe_open(source_lora_path, framework="pt", device="cpu") as f:
metadata = f.metadata() # Get metadata if it exists
if metadata is None: # Ensure metadata is a dict even if empty
metadata = {}
for key in f.keys():
source_lora_state_dict[key] = f.get_tensor(key) # Load tensors
print(f"Successfully loaded {len(source_lora_state_dict)} tensors.")
if metadata:
print(f"Found metadata: {metadata}")
else:
print("No metadata found.")
except Exception as e:
print(f"Error loading LoRA file: {e}")
import traceback
traceback.print_exc()
return
target_lora_state_dict = {}
skipped_keys = []
kept_keys = []
base_name_map = {} # Store mapping for reporting
print(f"\nConverting LoRA weights...")
print(f"Will skip LoRA weights targeting these base layers: {BASE_LAYERS_TO_SKIP_LORA}")
# Iterate through the loaded tensors
for key, tensor in source_lora_state_dict.items():
# Use the helper function to extract the base layer name
base_layer_name = get_base_layer_name(key)
base_name_map[key] = base_layer_name # Store for reporting purposes
# Check if the identified base layer name should be skipped
if base_layer_name in BASE_LAYERS_TO_SKIP_LORA:
skipped_keys.append(key)
else:
# Keep the tensor if its base layer is not in the skip list
target_lora_state_dict[key] = tensor
kept_keys.append(key)
# --- Reporting ---
print(f"\nConversion Summary:")
print(f" - Total Tensors in Source: {len(source_lora_state_dict)}")
print(f" - Kept {len(kept_keys)} LoRA weight tensors.")
print(f" - Skipped {len(skipped_keys)} LoRA weight tensors (due to incompatible base layer shape):")
if skipped_keys:
max_print = 15 # Show a few more skipped keys if desired
skipped_sorted = sorted(skipped_keys) # Sort for consistent output order
for i, key in enumerate(skipped_sorted):
base_name = base_name_map.get(key, "N/A") # Get the identified base name
print(f" - {key} (Base Layer Identified: {base_name})")
if i >= max_print -1 and len(skipped_keys) > max_print:
print(f" ... and {len(skipped_keys) - max_print} more.")
break
else:
print(" None")
# --- Saving ---
print(f"\nSaving converted LoRA ({len(target_lora_state_dict)} tensors) to: {target_lora_path}")
try:
# Save the filtered state dictionary with the original metadata
safetensors.torch.save_file(target_lora_state_dict, target_lora_path, metadata=metadata)
print("Conversion successful!")
except Exception as e:
print(f"Error saving converted LoRA file: {e}")
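
# A hedged sketch of calling the converter directly from Python rather than through the
# CLI below; the file names here are placeholders for illustration only.
#
#     convert_lora("wan_i2v_14B_lora.safetensors", "wan_i2v_14B_FC_lora.safetensors")
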
if __name__ == "__main__":
    # Setup argument parser
    parser = argparse.ArgumentParser(
        description="Convert Wan i2v_14B LoRA to i2v_14B_FC LoRA by removing incompatible patch_embedding weights.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("source_lora", type=str, help="Path to the source i2v_14B LoRA file (.safetensors).")
    parser.add_argument("target_lora", type=str, help="Path to save the converted i2v_14B_FC LoRA file (.safetensors).")

    # Parse arguments
    args = parser.parse_args()

    # --- Input Validation ---
    # Errors abort the run; warnings are printed but do not block the conversion.
    if not os.path.exists(args.source_lora):
        print(f"Error: Source LoRA file not found at '{args.source_lora}'")
    elif args.source_lora == args.target_lora:
        print(f"Error: Source and target paths cannot be the same ('{args.source_lora}'). Choose a different target path.")
    else:
        if not args.source_lora.lower().endswith(".safetensors"):
            print(f"Warning: Source file '{args.source_lora}' does not have a .safetensors extension.")
        if os.path.exists(args.target_lora):
            print(f"Warning: Target file '{args.target_lora}' already exists and will be overwritten.")
            # Optionally add a --force flag or prompt user here
        # Run the conversion once basic checks pass
        convert_lora(args.source_lora, args.target_lora)
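
# Example invocation from the command line (file names are placeholders; substitute your
# own LoRA paths):
#
#     python convert_lora_i2v_to_fc.py wan_i2v_14B_lora.safetensors wan_i2v_14B_FC_lora.safetensors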