Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -19,176 +19,6 @@ LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-MANUAL_PATCHES_STORE = {"diff": {}, "diff_b": {}}
-
-def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
-    global MANUAL_PATCHES_STORE
-    MANUAL_PATCHES_STORE = {"diff": {}, "diff_b": {}}  # Reset for each conversion
-    peft_compatible_state_dict = {}
-    unhandled_keys = []
-
-    original_keys_map_to_diffusers = {}
-
-    # Mapping based on ComfyUI's WanModel structure and PeftAdapterMixin logic
-    # This needs to map the original LoRA key naming to Diffusers' expected PEFT keys
-    # diffusion_model.blocks.0.self_attn.q.lora_down.weight -> transformer.blocks.0.attn1.to_q.lora_A.weight
-    # diffusion_model.blocks.0.ffn.0.lora_down.weight -> transformer.blocks.0.ffn.net.0.proj.lora_A.weight
-    # diffusion_model.text_embedding.0.lora_down.weight -> transformer.condition_embedder.text_embedder.linear_1.lora_A.weight (example)
-
-    # Strip "diffusion_model." and map
-    for k, v in state_dict.items():
-        original_k = k  # Keep for logging/debugging
-        if k.startswith("diffusion_model."):
-            k_stripped = k[len("diffusion_model."):]
-        elif k.startswith("difusion_model."):  # Handle potential typo
-            k_stripped = k[len("difusion_model."):]
-            logger.warning(f"Key '{original_k}' starts with 'difusion_model.' (potential typo), processing as 'diffusion_model.'.")
-        else:
-            unhandled_keys.append(original_k)
-            continue
-
-        # Handle .diff and .diff_b keys by storing them separately
-        if k_stripped.endswith(".diff"):
-            target_model_key = k_stripped[:-len(".diff")] + ".weight"
-            MANUAL_PATCHES_STORE["diff"][target_model_key] = v
-            continue
-        elif k_stripped.endswith(".diff_b"):
-            target_model_key = k_stripped[:-len(".diff_b")] + ".bias"
-            MANUAL_PATCHES_STORE["diff_b"][target_model_key] = v
-            continue
-
-        # Handle standard LoRA A/B matrices
-        if ".lora_down.weight" in k_stripped:
-            diffusers_key_base = k_stripped.replace(".lora_down.weight", "")
-            # Apply transformations similar to _convert_non_diffusers_wan_lora_to_diffusers from diffusers
-            # but adapt to the PEFT naming convention (lora_A/lora_B)
-            # This part needs careful mapping based on WanTransformer3DModel structure
-
-            # Example mappings (these need to be comprehensive for all layers)
-            if diffusers_key_base.startswith("blocks."):
-                parts = diffusers_key_base.split(".")
-                block_idx = parts[1]
-                attn_type = parts[2]  # self_attn or cross_attn
-                proj_type = parts[3]  # q, k, v, o
-
-                if attn_type == "self_attn":
-                    diffusers_peft_key = f"transformer.blocks.{block_idx}.attn1.to_{proj_type}.lora_A.weight"
-                elif attn_type == "cross_attn":
-                    # WanTransformer3DModel uses attn2 for cross-attention like features
-                    diffusers_peft_key = f"transformer.blocks.{block_idx}.attn2.to_{proj_type}.lora_A.weight"
-                else:  # ffn
-                    ffn_idx = proj_type  # "0" or "2"
-                    diffusers_peft_key = f"transformer.blocks.{block_idx}.ffn.net.{ffn_idx}.proj.lora_A.weight"
-            elif diffusers_key_base.startswith("text_embedding."):
-                idx_map = {"0": "linear_1", "2": "linear_2"}
-                idx = diffusers_key_base.split(".")[1]
-                diffusers_peft_key = f"transformer.condition_embedder.text_embedder.{idx_map[idx]}.lora_A.weight"
-            elif diffusers_key_base.startswith("time_embedding."):
-                idx_map = {"0": "linear_1", "2": "linear_2"}
-                idx = diffusers_key_base.split(".")[1]
-                diffusers_peft_key = f"transformer.condition_embedder.time_embedder.{idx_map[idx]}.lora_A.weight"
-            elif diffusers_key_base.startswith("time_projection."):  # Assuming '1' from your example
-                diffusers_peft_key = f"transformer.condition_embedder.time_proj.lora_A.weight"
-            elif diffusers_key_base.startswith("patch_embedding"):
-                # WanTransformer3DModel has 'patch_embedding' at the top level
-                diffusers_peft_key = f"transformer.patch_embedding.lora_A.weight"  # This needs to match how PEFT would name it
-            elif diffusers_key_base.startswith("head.head"):
-                diffusers_peft_key = f"transformer.proj_out.lora_A.weight"
-            else:
-                unhandled_keys.append(original_k)
-                continue
-
-            peft_compatible_state_dict[diffusers_peft_key] = v
-            original_keys_map_to_diffusers[k_stripped] = diffusers_peft_key
-
-        elif ".lora_up.weight" in k_stripped:
-            # Find the corresponding lora_down key to determine the base name
-            down_key_stripped = k_stripped.replace(".lora_up.weight", ".lora_down.weight")
-            if down_key_stripped in original_keys_map_to_diffusers:
-                diffusers_peft_key_A = original_keys_map_to_diffusers[down_key_stripped]
-                diffusers_peft_key_B = diffusers_peft_key_A.replace(".lora_A.weight", ".lora_B.weight")
-                peft_compatible_state_dict[diffusers_peft_key_B] = v
-            else:
-                unhandled_keys.append(original_k)
-        elif not (k_stripped.endswith(".alpha") or k_stripped.endswith(".dora_scale")):  # Alphas are handled by PEFT if lora_A/B present
-            unhandled_keys.append(original_k)
-
-
-    if unhandled_keys:
-        logger.warning(f"Custom Wan LoRA Converter: Unhandled keys: {unhandled_keys}")
-
-    return peft_compatible_state_dict
-
-
-def apply_manual_diff_patches(pipe_model, patches_store, lora_strength=1.0):
-    if not hasattr(pipe_model, "transformer"):
-        logger.error("Pipeline model does not have a 'transformer' attribute to patch.")
-        return
-
-    transformer = pipe_model.transformer
-    changed_params_count = 0
-
-    for key_base, diff_tensor in patches_store.get("diff", {}).items():
-        # key_base is like "blocks.0.self_attn.q.weight"
-        # We need to prepend "transformer." to match diffusers internal naming
-        target_key_full = f"transformer.{key_base}"
-        try:
-            module_path_parts = target_key_full.split('.')
-            param_name = module_path_parts[-1]
-            module_path = ".".join(module_path_parts[:-1])
-            module = transformer
-            for part in module_path.split('.')[1:]:  # Skip the first 'transformer'
-                module = getattr(module, part)
-
-            original_param = getattr(module, param_name)
-            if original_param.shape != diff_tensor.shape:
-                logger.warning(f"Shape mismatch for diff patch on {target_key_full}: model {original_param.shape}, lora {diff_tensor.shape}. Skipping.")
-                continue
-
-            with torch.no_grad():
-                scaled_diff = (lora_strength * diff_tensor.to(original_param.device, original_param.dtype))
-                original_param.data.add_(scaled_diff)
-            changed_params_count += 1
-        except AttributeError:
-            logger.warning(f"Could not find parameter {target_key_full} in transformer to apply diff patch.")
-        except Exception as e:
-            logger.error(f"Error applying diff patch to {target_key_full}: {e}")
-
-
-    for key_base, diff_b_tensor in patches_store.get("diff_b", {}).items():
-        # key_base is like "blocks.0.self_attn.q.bias"
-        target_key_full = f"transformer.{key_base}"
-        try:
-            module_path_parts = target_key_full.split('.')
-            param_name = module_path_parts[-1]
-            module_path = ".".join(module_path_parts[:-1])
-            module = transformer
-            for part in module_path.split('.')[1:]:
-                module = getattr(module, part)
-
-            original_param = getattr(module, param_name)
-            if original_param is None:
-                logger.warning(f"Bias parameter {target_key_full} is None in model. Skipping diff_b patch.")
-                continue
-
-            if original_param.shape != diff_b_tensor.shape:
-                logger.warning(f"Shape mismatch for diff_b patch on {target_key_full}: model {original_param.shape}, lora {diff_b_tensor.shape}. Skipping.")
-                continue
-
-            with torch.no_grad():
-                scaled_diff_b = (lora_strength * diff_b_tensor.to(original_param.device, original_param.dtype))
-                original_param.data.add_(scaled_diff_b)
-            changed_params_count += 1
-        except AttributeError:
-            logger.warning(f"Could not find parameter {target_key_full} in transformer to apply diff_b patch.")
-        except Exception as e:
-            logger.error(f"Error applying diff_b patch to {target_key_full}: {e}")
-    if changed_params_count > 0:
-        logger.info(f"Applied {changed_params_count} manual diff/diff_b patches.")
-    else:
-        logger.info("No manual diff/diff_b patches were applied.")
-
-
 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")
 vae = AutoencoderKLWan.from_pretrained(
@@ -214,26 +44,7 @@ logger.info(f"Downloading LoRA {LORA_FILENAME} from {LORA_REPO_ID}...")
 causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)

 logger.info("Loading LoRA weights with custom converter...")
-
-from safetensors.torch import load_file as load_safetensors
-raw_lora_state_dict = load_safetensors(causvid_path)
-
-# Now call our custom converter which will populate MANUAL_PATCHES_STORE
-peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)
-
-# Load the LoRA A/B matrices using PEFT
-if peft_state_dict:
-    pipe.load_lora_weights(
-        peft_state_dict,
-        adapter_name="causvid_lora"
-    )
-    logger.info("PEFT LoRA A/B weights loaded.")
-else:
-    logger.warning("No PEFT-compatible LoRA weights found after conversion.")
-
-# Apply manual diff_b and diff patches
-apply_manual_diff_patches(pipe, MANUAL_PATCHES_STORE, lora_strength=1.0)  # Assuming default strength 1.0
-logger.info("Manual diff_b/diff patches applied.")
+pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")


 # --- Gradio Interface Function ---
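After this change the Space relies on pipe.load_lora_weights to handle the Wan-style checkpoint directly from the downloaded .safetensors path. The sketch below shows the resulting loading path in isolation; it reuses pipe, LORA_REPO_ID, and LORA_FILENAME defined earlier in app.py, and the final set_adapters call is an assumption about the installed diffusers version (a way to control LoRA strength), not something this commit adds.

# Sketch of the simplified loading path after this commit. It assumes the
# installed diffusers release converts Wan (non-diffusers) LoRA keys internally
# when given a .safetensors file. `pipe`, LORA_REPO_ID and LORA_FILENAME come
# from earlier in app.py.
from huggingface_hub import hf_hub_download

causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)

# One call replaces the manual safetensors load, key conversion, and
# diff/diff_b patching that this commit removes.
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")

# Assumed API: recent diffusers exposes set_adapters for per-adapter scaling,
# standing in for the lora_strength argument of the deleted patcher.
pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])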