Update app.py
app.py CHANGED
@@ -192,7 +192,7 @@ def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     return final_peft_state_dict


-def apply_manual_diff_patches(pipe_model, patches):
+def apply_manual_diff_patches(pipe_model: torch.nn.Module, patches: Dict[str, torch.Tensor]):
     """
     Manually applies diff_b/diff patches to the model.
     Assumes PEFT LoRA layers have already been loaded.
@@ -204,87 +204,95 @@ def apply_manual_diff_patches(pipe_model, patches):
     logger.info(f"Applying {len(patches)} manual diff patches...")
     patched_keys_count = 0
     unpatched_keys_count = 0
+    skipped_keys_details = []

     for key, diff_tensor in patches.items():
         try:
+            # key is like "transformer.blocks.0.attn1.to_q.bias"
+            current_module = pipe_model # Starts from pipe.transformer
+            path_parts = key.split('.')[1:] # Remove "transformer." prefix for getattr navigation
+            # e.g., ["blocks", "0", "attn1", "to_q", "bias"]

-            # Navigate to the parent module
-            #
+            # Navigate to the parent module of the parameter
+            # Example: for "blocks.0.attn1.to_q.bias", parent_module_path is "blocks.0.attn1.to_q"
+            parent_module_path = path_parts[:-1]
+            param_name_to_patch = path_parts[-1] # "bias" or "weight"
+
+            for part in parent_module_path:
+                if hasattr(current_module, part):
+                    current_module = getattr(current_module, part)
+                elif hasattr(current_module, 'base_layer') and hasattr(current_module.base_layer, part):
+                    # This case is unlikely here as we are navigating *to* the layer,
+                    # not trying to access a sub-component of a base_layer.
+                    # PEFT wrapping affects the layer itself, not its parent structure.
+                    current_module = getattr(current_module.base_layer, part)
                 else:
+                    raise AttributeError(f"Submodule '{part}' not found in path '{'.'.join(parent_module_path)}' within {key}")
+
+            # Now, current_module is the layer whose parameter we want to patch
+            # e.g., if key was transformer.blocks.0.attn1.to_q.bias,
+            # current_module is the to_q Linear layer (or LoraLayer wrapping it)
+
+            layer_to_modify = current_module
+            # If PEFT wrapped the Linear layer (common for attention q,k,v,o and ffn projections)
+            if hasattr(layer_to_modify, "base_layer") and isinstance(layer_to_modify.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
+                actual_param_owner = layer_to_modify.base_layer
+            else: # For non-wrapped layers like LayerNorm, or if it's already the base_layer
+                actual_param_owner = layer_to_modify

-            # If PEFT wrapped it, the actual nn.Linear or nn.LayerNorm is in `base_layer`
-            if hasattr(target_layer, "base_layer") and isinstance(target_layer.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
-                layer_to_modify = target_layer.base_layer
-            else:
-                layer_to_modify = target_layer
+            if not hasattr(actual_param_owner, param_name_to_patch):
+                skipped_keys_details.append(f"Key: {key}, Reason: Parameter '{param_name_to_patch}' not found in layer '{actual_param_owner}'. Layer type: {type(actual_param_owner)}")
+                unpatched_keys_count += 1
+                continue

+            original_param = getattr(actual_param_owner, param_name_to_patch)
+
+            if original_param is None and param_name_to_patch == "bias":
+                logger.info(f"Key '{key}': Original bias is None. Attempting to initialize.")
+                if isinstance(actual_param_owner, torch.nn.Linear) or isinstance(actual_param_owner, torch.nn.LayerNorm):
+                    # For LayerNorm, bias exists if elementwise_affine=True (default).
+                    # If it was False, we are making it affine by adding a bias.
+                    # For Linear, if bias was False, we are adding one.
+                    actual_param_owner.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
+                    original_param = actual_param_owner.bias
+                    logger.info(f"Key '{key}': Initialized bias for {type(actual_param_owner)}.")
+                else:
+                    skipped_keys_details.append(f"Key: {key}, Reason: Original bias is None and layer '{actual_param_owner}' is not Linear or LayerNorm. Cannot initialize.")
+                    unpatched_keys_count +=1
+                    continue
+
+            # Special handling for RMSNorm which typically has no bias
+            if isinstance(actual_param_owner, torch.nn.RMSNorm) and param_name_to_patch == "bias":
+                skipped_keys_details.append(f"Key: {key}, Reason: Layer '{actual_param_owner}' is RMSNorm which has no bias parameter. Skipping bias diff.")
                unpatched_keys_count +=1
                continue

-            original_param = getattr(layer_to_modify, param_name)
-
-            if original_param is None and param_name == "bias":
-                # If bias is None (e.g., LayerNorm with elementwise_affine=False, or Linear(bias=False)),
-                # we might need to initialize it if the diff expects to add to it.
-                # For Linear layers, if bias was False, it should remain False unless LoRA intends to add one.
-                # For LayerNorm, if elementwise_affine was False, adding a bias diff means it becomes affine.
-                if isinstance(layer_to_modify, torch.nn.Linear):
-                    if layer_to_modify.bias is None: # Check if bias was intentionally None
-                        logger.warning(f"Original layer {layer_to_modify} for key '{key}' has no bias. Creating one to apply diff_b. This might be unintended if bias=False was set.")
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                    else: # Should not happen if original_param was None but layer_to_modify.bias isn't
-                        pass
-                elif isinstance(layer_to_modify, torch.nn.LayerNorm):
-                    if not layer_to_modify.elementwise_affine:
-                        logger.warning(f"LayerNorm {layer_to_modify} for key '{key}' was not elementwise_affine. Applying bias diff will make it effectively affine for bias.")
-                        # LayerNorm bias is initialized to zeros if elementwise_affine is True
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                        # Also need to ensure weight exists if a weight diff is applied later
-                        if param_name == "bias" and not hasattr(layer_to_modify, "weight"):
-                            layer_to_modify.weight = torch.nn.Parameter(torch.ones_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype)) # Norm weights init to 1

             if original_param is not None:
                 if original_param.shape != diff_tensor.shape:
-                    unpatched_keys_count +=1
+                    skipped_keys_details.append(f"Key: {key}, Reason: Shape mismatch. Model param: {original_param.shape}, LoRA diff: {diff_tensor.shape}. Layer: {actual_param_owner}")
+                    unpatched_keys_count += 1
                     continue
                 with torch.no_grad():
                     original_param.add_(diff_tensor.to(original_param.device, original_param.dtype))
-                logger.info(f"Successfully applied diff to '{key}'")
-                patched_keys_count +=1
+                # logger.info(f"Successfully applied diff to '{key}'") # Too verbose, will log summary
+                patched_keys_count += 1
             else:
-                unpatched_keys_count +=1
+                skipped_keys_details.append(f"Key: {key}, Reason: Original parameter '{param_name_to_patch}' is None and was not initialized. Layer: {actual_param_owner}")
+                unpatched_keys_count += 1

         except AttributeError as e:
-            unpatched_keys_count +=1
+            skipped_keys_details.append(f"Key: {key}, Reason: AttributeError - {e}")
+            unpatched_keys_count += 1
         except Exception as e:
-            unpatched_keys_count +=1
-    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+            skipped_keys_details.append(f"Key: {key}, Reason: General Exception - {e}")
+            unpatched_keys_count += 1

+    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+    if unpatched_keys_count > 0:
+        logger.warning("Details of unpatched/skipped keys:")
+        for detail in skipped_keys_details:
+            logger.warning(f"  - {detail}")

 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")
@@ -411,6 +419,7 @@ with gr.Blocks() as demo:
             width_input,
             num_frames_input,
             guidance_scale_input,
+            steps,
             fps_input
         ],
         outputs=video_output
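
For readers skimming the diff: the new helper walks a dotted state-dict key down to the layer that owns the parameter, unwraps a PEFT-style `base_layer` if the layer was wrapped, and adds the diff tensor to the parameter in place. Below is a minimal, self-contained sketch of that same technique on a toy module; the class names, module layout, and example key are illustrative assumptions, not the Space's actual model.

# Sketch of the technique in apply_manual_diff_patches, on a toy module.
# LoraWrapped/Block and the example key are illustrative only.
import torch


class LoraWrapped(torch.nn.Module):
    """Stand-in for a PEFT LoraLayer: the original layer sits in `base_layer`."""

    def __init__(self, base_layer: torch.nn.Module):
        super().__init__()
        self.base_layer = base_layer


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = LoraWrapped(torch.nn.Linear(4, 4))


transformer = torch.nn.ModuleDict({"blocks": torch.nn.ModuleList([Block()])})
patches = {"transformer.blocks.0.to_q.bias": torch.full((4,), 0.5)}

for key, diff in patches.items():
    *parent_path, param_name = key.split(".")[1:]   # drop the "transformer." prefix
    module = transformer
    for part in parent_path:                         # blocks -> 0 -> to_q
        module = getattr(module, part)
    owner = getattr(module, "base_layer", module)    # unwrap PEFT-style wrappers
    param = getattr(owner, param_name)
    if param is not None and param.shape == diff.shape:
        with torch.no_grad():
            param.add_(diff.to(param.device, param.dtype))

print(transformer["blocks"][0].to_q.base_layer.bias)  # bias shifted by +0.5

Judging by the in-line comments ("Starts from pipe.transformer", stripping the "transformer." prefix), the Space presumably calls the real helper along the lines of apply_manual_diff_patches(pipe.transformer, patches) after the PEFT LoRA weights are loaded; the exact call site and argument names are not shown in this diff.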
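The final hunk wires a new `steps` component into the click event's `inputs` list. In Gradio, each component in that list is passed positionally to the handler, so the generation function must accept a matching parameter in the same position. A minimal wiring sketch of that pattern follows; component and handler names are illustrative, not taken from app.py.

# Minimal Gradio wiring sketch: every component in `inputs` maps positionally
# onto a parameter of the click handler. Names below are illustrative.
import gradio as gr


def generate(prompt, guidance_scale, steps, fps):
    return f"{steps} steps at guidance {guidance_scale}, {fps} fps for: {prompt}"


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    guidance_scale_input = gr.Slider(1.0, 10.0, value=5.0, label="Guidance scale")
    steps = gr.Slider(1, 50, value=30, step=1, label="Inference steps")
    fps_input = gr.Slider(4, 30, value=16, step=1, label="FPS")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    # Adding `steps` here only works because `generate` accepts it in the same position.
    btn.click(generate, inputs=[prompt, guidance_scale_input, steps, fps_input], outputs=out)

demo.launch()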