| import argparse |
| import pathlib |
| from typing import Any, Dict, Tuple |
|
|
| import torch |
| from accelerate import init_empty_weights |
| from huggingface_hub import hf_hub_download, snapshot_download |
| from safetensors.torch import load_file |
| from transformers import ( |
| AutoProcessor, |
| AutoTokenizer, |
| CLIPImageProcessor, |
| CLIPVisionModel, |
| CLIPVisionModelWithProjection, |
| UMT5EncoderModel, |
| ) |
|
|
| from diffusers import ( |
| AutoencoderKLWan, |
| UniPCMultistepScheduler, |
| WanAnimatePipeline, |
| WanAnimateTransformer3DModel, |
| WanImageToVideoPipeline, |
| WanPipeline, |
| WanTransformer3DModel, |
| WanVACEPipeline, |
| WanVACETransformer3DModel, |
| ) |
|
|
|
|
# Substring rename table mapping original Wan checkpoint key fragments to
# their diffusers equivalents. It is applied by iterating items in insertion
# order and running str.replace on every state-dict key, so ORDER MATTERS:
# e.g. "head.modulation" must precede "modulation", and the norm2/norm3 swap
# below relies on a temporary placeholder.
TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # norm2 and norm3 trade places between the two layouts. Because renames
    # are applied sequentially, go through a placeholder so the second rename
    # does not clobber the result of the first.
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    # Image-conditioning embedder (present in I2V checkpoints only).
    "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
    "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
    "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
    "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    # Positional embedding for the image embedder (FLF2V checkpoints).
    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
    # Attention renames: self-attention -> attn1, cross-attention -> attn2.
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
}
|
|
# Rename table for Wan-VACE checkpoints. Same scheme as
# TRANSFORMER_KEYS_RENAME_DICT (ordered substring replacement), without the
# image-embedder entries and with the VACE-specific in/out projections added.
VACE_TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # Swap norm2 <-> norm3 via a placeholder (renames run sequentially).
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    # Attention renames: self-attention -> attn1, cross-attention -> attn2.
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
    # VACE control-branch projections.
    "before_proj": "proj_in",
    "after_proj": "proj_out",
}
|
|
# Rename table for Wan2.2-Animate checkpoints. Same ordered substring-replace
# scheme as TRANSFORMER_KEYS_RENAME_DICT, extended with motion-encoder and
# face-encoder renames. Structural conversions (tensor splits, buffer drops)
# are handled separately by ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP.
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
    "time_projection.1": "condition_embedder.time_proj",
    "head.modulation": "scale_shift_table",
    "head.head": "proj_out",
    "modulation": "scale_shift_table",
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
    # Swap norm2 <-> norm3 via a placeholder (renames run sequentially).
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
    "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
    "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
    "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
    "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    # Attention renames: self-attention -> attn1, cross-attention -> attn2.
    "self_attn.q": "attn1.to_q",
    "self_attn.k": "attn1.to_k",
    "self_attn.v": "attn1.to_v",
    "self_attn.o": "attn1.to_out.0",
    "self_attn.norm_q": "attn1.norm_q",
    "self_attn.norm_k": "attn1.norm_k",
    "cross_attn.q": "attn2.to_q",
    "cross_attn.k": "attn2.to_k",
    "cross_attn.v": "attn2.to_v",
    "cross_attn.o": "attn2.to_out.0",
    "cross_attn.norm_q": "attn2.norm_q",
    "cross_attn.norm_k": "attn2.norm_k",
    # Two-stage rename for the image KV projections: first map the original
    # "cross_attn.*_img" names onto "attn2.*_img", which the entries below
    # then map onto the diffusers added-KV names (insertion order matters).
    "cross_attn.k_img": "attn2.to_k_img",
    "cross_attn.v_img": "attn2.to_v_img",
    "cross_attn.norm_k_img": "attn2.norm_k_img",
    "attn2.to_k_img": "attn2.add_k_proj",
    "attn2.to_v_img": "attn2.add_v_proj",
    "attn2.norm_k_img": "attn2.norm_added_k",
    # Motion encoder renames; the remaining motion-encoder keys are handled
    # structurally by convert_animate_motion_encoder_weights.
    "motion_encoder.enc.fc": "motion_encoder.motion_network",
    "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
    # Face encoder: drop the wrapper ".conv" level around each conv layer.
    "face_encoder.conv1_local.conv": "face_encoder.conv1_local",
    "face_encoder.conv2.conv": "face_encoder.conv2",
    "face_encoder.conv3.conv": "face_encoder.conv3",
}
|
|
|
|
| |
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
    """
    Convert all motion encoder weights for Animate model.

    In the original model:
    - All Linear layers in fc use EqualLinear
    - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
    - Blur kernels are stored as buffers in Sequential modules
    - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]

    Conversion strategy:
    1. Drop .kernel buffers (blur kernels)
    2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)

    Args:
        key: A (possibly already-renamed) state-dict key to inspect.
        state_dict: The full state dict; mutated in place (keys popped/re-added).
        final_conv_idx: Sequential index of the last conv in ``enc.net_app.convs``;
            indices between 0 and this map to ``res_blocks`` entries.
    """
    # Only parameters and buffers are of interest; anything else is untouched.
    if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
        return

    # Blur kernels are fixed buffers re-created by the diffusers module; drop them.
    if ".kernel" in key and "motion_encoder" in key:
        state_dict.pop(key, None)
        return

    # Appearance-network convs: remap nn.Sequential indices to named modules.
    if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
        parts = key.split(".")

        convs_idx = parts.index("convs") if "convs" in parts else -1
        if convs_idx >= 0 and len(parts) - convs_idx >= 2:
            # Conv biases belong to the fused leaky-ReLU activation in the
            # diffusers layout and are squeezed to 1-D below.
            bias = False
            sequential_idx = int(parts[convs_idx + 1])
            if sequential_idx == 0:
                # First conv -> conv_in.
                if key.endswith(".weight"):
                    new_key = "motion_encoder.conv_in.weight"
                elif key.endswith(".bias"):
                    new_key = "motion_encoder.conv_in.act_fn.bias"
                    bias = True
            elif sequential_idx == final_conv_idx:
                # Last conv -> conv_out.
                # NOTE(review): only ".weight" is handled here; if the final
                # conv carried a bias, `new_key` would be unbound below
                # (presumably the original conv_out is bias-free — confirm).
                if key.endswith(".weight"):
                    new_key = "motion_encoder.conv_out.weight"
            else:
                # Intermediate convs -> res_blocks.{idx-1}.{layer}.{param}.
                prefix = "motion_encoder.res_blocks."

                layer_name = parts[convs_idx + 2]
                if layer_name == "skip":
                    layer_name = "conv_skip"

                if key.endswith(".weight"):
                    param_name = "weight"
                elif key.endswith(".bias"):
                    param_name = "act_fn.bias"
                    bias = True

                # Shift by one: sequential index 0 is conv_in, so index 1 maps
                # to res_blocks.0.
                suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
                suffix = ".".join(suffix_parts)
                new_key = prefix + suffix

            param = state_dict.pop(key)
            if bias:
                # Flatten the checkpoint's bias to 1-D for the activation module.
                param = param.squeeze()
            state_dict[new_key] = param
            return
        return
    return
|
|
|
|
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
    """
    Convert face adapter weights for the Animate model.

    The original model uses a fused KV projection but the diffusers model uses
    separate K and V projections, so ``linear1_kv`` tensors are split in half
    along dim 0; every other fuser-block layer is renamed in place.
    """
    # Only parameters are remapped; other entries pass through untouched.
    if ".weight" not in key and ".bias" not in key:
        return

    if ".fuser_blocks." not in key:
        return

    parts = key.split(".")
    anchor = parts.index("fuser_blocks")
    # Only keys shaped like "...fuser_blocks.<block>.<layer>.<param>" qualify.
    if (len(parts) - 1) - anchor != 3:
        return

    block_idx, layer_name, param_name = parts[anchor + 1 : anchor + 4]
    prefix = "face_adapter."

    if layer_name == "linear1_kv":
        # Split the fused KV projection into separate K and V halves.
        fused = state_dict.pop(key)
        k_half, v_half = torch.chunk(fused, 2, dim=0)
        state_dict[prefix + ".".join([block_idx, "to_k", param_name])] = k_half
        state_dict[prefix + ".".join([block_idx, "to_v", param_name])] = v_half
        return

    # Plain one-to-one layer renames.
    if layer_name == "q_norm":
        mapped_layer = "norm_q"
    elif layer_name == "k_norm":
        mapped_layer = "norm_k"
    elif layer_name == "linear1_q":
        mapped_layer = "to_q"
    elif layer_name == "linear2":
        mapped_layer = "to_out"

    state_dict[prefix + ".".join([block_idx, mapped_layer, param_name])] = state_dict.pop(key)
|
|
|
|
# Substring -> handler callbacks applied after the plain renames above.
# Handlers mutate the state dict in place (splitting fused tensors, dropping
# buffers). Plain Wan and VACE checkpoints need no structural conversions.
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "motion_encoder": convert_animate_motion_encoder_weights,
    "face_adapter": convert_animate_face_adapter_weights,
}
|
|
|
|
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Rename ``old_key`` to ``new_key`` in ``state_dict``, in place.

    Fix: the return annotation previously claimed ``Dict[str, Any]`` although
    the function mutates its argument and returns ``None`` (as the trailing
    underscore naming convention indicates). Runtime behavior is unchanged.

    Raises:
        KeyError: if ``old_key`` is not present in ``state_dict``.
    """
    state_dict[new_key] = state_dict.pop(old_key)
|
|
|
|
def load_sharded_safetensors(dir: pathlib.Path):
    """Load every ``diffusion_pytorch_model*.safetensors`` shard found in
    ``dir`` and merge them into a single flat state dict."""
    merged = {}
    for shard_path in dir.glob("diffusion_pytorch_model*.safetensors"):
        merged.update(load_file(shard_path))
    return merged
|
|
|
|
def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], Dict[str, str], Dict[str, Any]]:
    """Return the conversion recipe for ``model_type``.

    Returns a 3-tuple of:
      - ``config``: the source Hub ``model_id`` plus the ``diffusers_config``
        kwargs used to instantiate the target transformer class,
      - the ordered substring rename dict applied to every checkpoint key,
      - the special-key handler dict for structural conversions.

    Raises:
        ValueError: if ``model_type`` is not a supported checkpoint name.
            (Fix: previously an unknown ``model_type`` fell through every
            branch and crashed with ``UnboundLocalError`` at the ``return``.)
    """
    if model_type == "Wan-T2V-1.3B":
        config = {
            "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 12,
                "num_layers": 30,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-T2V-14B":
        config = {
            "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-I2V-14B-480p":
        # I2V variants add CLIP image conditioning (image_dim/added_kv_proj_dim)
        # and take 36 input channels (latents + conditioning).
        config = {
            "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-I2V-14B-720p":
        config = {
            "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-FLF2V-14B-720P":
        # First/last-frame conditioning doubles the image pos-embed length.
        config = {
            "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "rope_max_seq_len": 1024,
                "pos_embed_seq_len": 257 * 2,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-VACE-1.3B":
        config = {
            "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 12,
                "num_layers": 30,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
                "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan-VACE-14B":
        config = {
            "model_id": "Wan-AI/Wan2.1-VACE-14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
                "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-VACE-Fun-14B":
        config = {
            "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
                "vace_in_channels": 96,
            },
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-I2V-14B-720p":
        # Wan2.2 I2V has no CLIP image branch (no image_dim/added_kv_proj_dim).
        config = {
            "model_id": "Wan-AI/Wan2.2-I2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-T2V-A14B":
        config = {
            "model_id": "Wan-AI/Wan2.2-T2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-TI2V-5B":
        config = {
            "model_id": "Wan-AI/Wan2.2-TI2V-5B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 14336,
                "freq_dim": 256,
                "in_channels": 48,
                "num_attention_heads": 24,
                "num_layers": 30,
                "out_channels": 48,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-Animate-14B":
        config = {
            "model_id": "Wan-AI/Wan2.2-Animate-14B",
            "diffusers_config": {
                "image_dim": 1280,
                "added_kv_proj_dim": 5120,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": (1, 2, 2),
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
                "rope_max_seq_len": 1024,
                "pos_embed_seq_len": None,
                "motion_encoder_size": 512,
                "motion_style_dim": 512,
                "motion_dim": 20,
                "motion_encoder_dim": 512,
                "face_encoder_hidden_dim": 1024,
                "face_encoder_num_heads": 4,
                "inject_face_latents_blocks": 5,
            },
        }
        RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
    else:
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    return config, RENAME_DICT, SPECIAL_KEYS_REMAP
|
|
|
|
def convert_transformer(model_type: str, stage: str = None):
    """Download the original checkpoint for ``model_type`` and convert it into
    the matching diffusers transformer (Animate / VACE / plain Wan).

    ``stage`` selects a subfolder of the downloaded snapshot (used by the
    Wan2.2 A14B repos, which ship one checkpoint per denoising stage).
    """
    config, rename_dict, special_keys_remap = get_transformer_config(model_type)

    checkpoint_dir = pathlib.Path(snapshot_download(config["model_id"], repo_type="model"))
    if stage is not None:
        checkpoint_dir = checkpoint_dir / stage

    state_dict = load_sharded_safetensors(checkpoint_dir)

    # Pick the target architecture, then build it on the meta device; real
    # tensors are attached later through load_state_dict(..., assign=True).
    if "Animate" in model_type:
        model_cls = WanAnimateTransformer3DModel
    elif "VACE" in model_type:
        model_cls = WanVACETransformer3DModel
    else:
        model_cls = WanTransformer3DModel
    with init_empty_weights():
        transformer = model_cls.from_config(config["diffusers_config"])

    # Pass 1: substring renames, applied in dict insertion order (the
    # norm2/norm3 placeholder swap in the rename tables depends on it).
    for checkpoint_key in list(state_dict.keys()):
        renamed_key = checkpoint_key
        for source_fragment, target_fragment in rename_dict.items():
            renamed_key = renamed_key.replace(source_fragment, target_fragment)
        update_state_dict_(state_dict, checkpoint_key, renamed_key)

    # Pass 2: structural conversions (tensor splits, buffer drops) via handlers.
    for current_key in list(state_dict.keys()):
        for marker, handler in special_keys_remap.items():
            if marker in current_key:
                handler(current_key, state_dict)

    transformer.load_state_dict(state_dict, strict=True, assign=True)
    return transformer.to("cpu")
|
|
|
|
def convert_vae():
    """Download the original Wan2.1 VAE checkpoint and convert its state dict
    into the diffusers ``AutoencoderKLWan`` layout.

    Returns the converted ``AutoencoderKLWan`` instance (weights assigned via
    ``load_state_dict(..., assign=True)`` after construction on empty weights).
    """
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Exact-key renames for the encoder/decoder middle (mid_block) resnets.
    middle_key_mapping = {
        # Encoder mid-block resnets.
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder mid-block resnets.
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Exact-key renames for the mid-block attention layers.
    attention_mapping = {
        # Encoder mid-block attention.
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder mid-block attention.
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Exact-key renames for the output heads (norm_out + conv_out).
    head_mapping = {
        # Encoder head.
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head.
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Latent (post-)quantization convs.
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Walk every original key once; the branch order matters because the
    # exact-key tables take precedence over the prefix-based rules below.
    for key, value in old_state_dict.items():
        # Mid-block resnets.
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Mid-block attentions.
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Output heads.
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Quant convs.
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Input convs.
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Encoder down blocks: simple substring renames per layer kind.
        elif key.startswith("encoder.downsamples."):
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Map the sequential indices inside each residual block onto the
            # named norm/conv layers of the diffusers resnet.
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            new_state_dict[new_key] = value

        # Decoder up blocks: the original uses a flat list of 15 modules that
        # must be regrouped into 4 up_blocks of 3 resnets (+ upsamplers).
        elif key.startswith("decoder.upsamples."):
            # Key layout: decoder.upsamples.<flat_idx>....
            parts = key.split(".")
            block_idx = int(parts[2])

            # Residual layers: flat indices {0,1,2}, {4,5,6}, {8,9,10},
            # {12,13,14} become resnets 0-2 of up_blocks 0-3.
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Unknown flat index: keep the key as-is and move on.
                    new_state_dict[key] = value
                    continue

                # Rename the sequential residual layers to norm/conv names.
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                new_state_dict[new_key] = value

            # Shortcut convs; flat index 4 is the first resnet of up_block 1.
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                new_state_dict[new_key] = value

            # Upsampler layers sit at flat indices 3, 7, 11 (after each group
            # of three resnets) and become upsamplers.0 of up_blocks 0-2.
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                new_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
        else:
            # Everything else (e.g. conv2 intermediates handled above, or keys
            # already in diffusers form) is copied through unchanged.
            new_state_dict[key] = value

    # Build on the meta device, then attach the converted tensors directly.
    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae
|
|
|
|
# Diffusers configuration for the Wan2.2 (TI2V-5B) VAE: 48 latent channels
# (z_dim), 16x spatial / 4x temporal compression, operating on a patchified
# 12-channel pixel space (patch_size=2). latents_mean / latents_std are the
# 48 per-channel latent statistics shipped with the checkpoint.
vae22_diffusers_config = {
    "base_dim": 160,
    "z_dim": 48,
    "is_residual": True,
    "in_channels": 12,
    "out_channels": 12,
    "decoder_base_dim": 256,
    "scale_factor_temporal": 4,
    "scale_factor_spatial": 16,
    "patch_size": 2,
    "latents_mean": [
        -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
        -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
        -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
        -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
        -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
        0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
    ],
    "latents_std": [
        0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
        0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
        0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
        0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
        0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
        0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
    ],
    "clip_output": False,
}
|
|
|
|
| def convert_vae_22(): |
| vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth") |
| old_state_dict = torch.load(vae_ckpt_path, weights_only=True) |
| new_state_dict = {} |
|
|
| |
| middle_key_mapping = { |
| |
| "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", |
| "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", |
| "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", |
| "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", |
| "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", |
| "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", |
| "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", |
| "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", |
| "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", |
| "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", |
| "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", |
| "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", |
| |
| "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", |
| "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", |
| "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", |
| "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", |
| "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", |
| "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", |
| "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", |
| "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", |
| "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", |
| "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", |
| "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", |
| "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", |
| } |
|
|
| |
| attention_mapping = { |
| |
| "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", |
| "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", |
| "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", |
| "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", |
| "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", |
| |
| "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", |
| "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", |
| "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", |
| "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", |
| "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", |
| } |
|
|
| |
| head_mapping = { |
| |
| "encoder.head.0.gamma": "encoder.norm_out.gamma", |
| "encoder.head.2.bias": "encoder.conv_out.bias", |
| "encoder.head.2.weight": "encoder.conv_out.weight", |
| |
| "decoder.head.0.gamma": "decoder.norm_out.gamma", |
| "decoder.head.2.bias": "decoder.conv_out.bias", |
| "decoder.head.2.weight": "decoder.conv_out.weight", |
| } |
|
|
| |
| quant_mapping = { |
| "conv1.weight": "quant_conv.weight", |
| "conv1.bias": "quant_conv.bias", |
| "conv2.weight": "post_quant_conv.weight", |
| "conv2.bias": "post_quant_conv.bias", |
| } |
|
|
| |
| for key, value in old_state_dict.items(): |
| |
| if key in middle_key_mapping: |
| new_key = middle_key_mapping[key] |
| new_state_dict[new_key] = value |
| |
| elif key in attention_mapping: |
| new_key = attention_mapping[key] |
| new_state_dict[new_key] = value |
| |
| elif key in head_mapping: |
| new_key = head_mapping[key] |
| new_state_dict[new_key] = value |
| |
| elif key in quant_mapping: |
| new_key = quant_mapping[key] |
| new_state_dict[new_key] = value |
| |
| elif key == "encoder.conv1.weight": |
| new_state_dict["encoder.conv_in.weight"] = value |
| elif key == "encoder.conv1.bias": |
| new_state_dict["encoder.conv_in.bias"] = value |
| |
| elif key == "decoder.conv1.weight": |
| new_state_dict["decoder.conv_in.weight"] = value |
| elif key == "decoder.conv1.bias": |
| new_state_dict["decoder.conv_in.bias"] = value |
| |
| elif key.startswith("encoder.downsamples."): |
| |
| new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") |
|
|
| |
| if "residual" in new_key or "shortcut" in new_key: |
| |
| new_key = new_key.replace(".downsamples.", ".resnets.") |
|
|
| |
| if ".residual.0.gamma" in new_key: |
| new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") |
| elif ".residual.2.weight" in new_key: |
| new_key = new_key.replace(".residual.2.weight", ".conv1.weight") |
| elif ".residual.2.bias" in new_key: |
| new_key = new_key.replace(".residual.2.bias", ".conv1.bias") |
| elif ".residual.3.gamma" in new_key: |
| new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") |
| elif ".residual.6.weight" in new_key: |
| new_key = new_key.replace(".residual.6.weight", ".conv2.weight") |
| elif ".residual.6.bias" in new_key: |
| new_key = new_key.replace(".residual.6.bias", ".conv2.bias") |
| elif ".shortcut.weight" in new_key: |
| new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") |
| elif ".shortcut.bias" in new_key: |
| new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") |
|
|
| |
| elif "resample" in new_key or "time_conv" in new_key: |
| |
| parts = new_key.split(".") |
| |
| |
| if len(parts) >= 4 and parts[3] == "downsamples": |
| |
| new_parts = parts[:3] + ["downsampler"] + parts[5:] |
| new_key = ".".join(new_parts) |
|
|
| new_state_dict[new_key] = value |
|
|
| |
| elif key.startswith("decoder.upsamples."): |
| |
| new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") |
|
|
| |
| if "residual" in new_key or "shortcut" in new_key: |
| |
| new_key = new_key.replace(".upsamples.", ".resnets.") |
|
|
| |
| if ".residual.0.gamma" in new_key: |
| new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") |
| elif ".residual.2.weight" in new_key: |
| new_key = new_key.replace(".residual.2.weight", ".conv1.weight") |
| elif ".residual.2.bias" in new_key: |
| new_key = new_key.replace(".residual.2.bias", ".conv1.bias") |
| elif ".residual.3.gamma" in new_key: |
| new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") |
| elif ".residual.6.weight" in new_key: |
| new_key = new_key.replace(".residual.6.weight", ".conv2.weight") |
| elif ".residual.6.bias" in new_key: |
| new_key = new_key.replace(".residual.6.bias", ".conv2.bias") |
| elif ".shortcut.weight" in new_key: |
| new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") |
| elif ".shortcut.bias" in new_key: |
| new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") |
|
|
| |
| elif "resample" in new_key or "time_conv" in new_key: |
| |
| parts = new_key.split(".") |
| |
| |
| if len(parts) >= 4 and parts[3] == "upsamples": |
| |
| new_parts = parts[:3] + ["upsampler"] + parts[5:] |
| new_key = ".".join(new_parts) |
|
|
| new_state_dict[new_key] = value |
| else: |
| |
| new_state_dict[key] = value |
|
|
| with init_empty_weights(): |
| vae = AutoencoderKLWan(**vae22_diffusers_config) |
| vae.load_state_dict(new_state_dict, strict=True, assign=True) |
| return vae |
|
|
|
|
def get_args():
    """Parse command-line arguments for the Wan -> diffusers conversion script.

    Returns:
        argparse.Namespace with:
            model_type (str): name of the Wan checkpoint variant to convert
                (the script dispatches on substrings such as "Wan2.2", "I2V",
                "TI2V", "VACE", "Animate"). Required: the previous
                ``default=None`` made every ``"..." in args.model_type``
                membership test crash with a ``TypeError``; argparse now
                reports a clean usage error instead.
            output_path (str): directory the converted pipeline is saved to.
            dtype (str): dtype to cast the transformer(s) to; "none" keeps
                the weights in their original dtype.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
    return parser.parse_args()
|
|
|
|
# CLI ``--dtype`` flag -> torch dtype used to cast the converted transformer(s).
# "none" is deliberately absent: it means "do not cast" and is short-circuited
# before this mapping is consulted (see the ``args.dtype != "none"`` guard in
# ``__main__``).
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}
|
|
|
|
if __name__ == "__main__":
    args = get_args()

    # Wan 2.2 two-expert checkpoints (everything except TI2V and Animate) ship
    # a high-noise and a low-noise transformer; convert both. All other
    # variants have a single transformer.
    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
        transformer = convert_transformer(args.model_type, stage="high_noise_model")
        transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
    else:
        transformer = convert_transformer(args.model_type)
        transformer_2 = None

    # Only Wan 2.2 TI2V uses the updated (2.2) VAE; every other variant uses
    # the original VAE converter.
    if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        vae = convert_vae_22()
    else:
        vae = convert_vae()

    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
    # Flow shift differs per task family (values taken from the original Wan
    # sampling configs).
    if "FLF2V" in args.model_type:
        flow_shift = 16.0
    elif "TI2V" in args.model_type or "Animate" in args.model_type:
        flow_shift = 5.0
    else:
        flow_shift = 3.0
    scheduler = UniPCMultistepScheduler(
        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
    )

    # Optionally cast the transformer weights; text encoder / VAE keep the
    # dtypes chosen at load/conversion time.
    if args.dtype != "none":
        dtype = DTYPE_MAPPING[args.dtype]
        transformer.to(dtype)
        if transformer_2 is not None:
            transformer_2.to(dtype)

    # BUGFIX: the previous conditions read ``if "Wan2.2" and "X" in
    # args.model_type`` — the string literal "Wan2.2" is a truthy constant, so
    # only the second membership was ever tested. That routed Wan 2.1 I2V/T2V
    # checkpoints into the Wan 2.2 branches (wrong boundary_ratio, missing
    # image encoder). Each condition now tests both substrings.
    if "Wan2.2" in args.model_type and "I2V" in args.model_type and "TI2V" not in args.model_type:
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.9,
        )
    elif "Wan2.2" in args.model_type and "T2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.875,
        )
    elif "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            expand_timesteps=True,
        )
    elif "I2V" in args.model_type or "FLF2V" in args.model_type:
        # Wan 2.1 image-conditioned variants need the CLIP image encoder.
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
        image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )
    elif "Wan2.2-VACE" in args.model_type:
        pipe = WanVACEPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.875,
        )
    elif "Wan-VACE" in args.model_type:
        pipe = WanVACEPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )
    elif "Animate" in args.model_type:
        image_encoder = CLIPVisionModel.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
        image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
        pipe = WanAnimatePipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )
    else:
        # Plain Wan 2.1 T2V (and any unrecognized variant) -> base pipeline.
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        )

    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")