| | import argparse |
| | from contextlib import nullcontext |
| |
|
| | import torch |
| | from accelerate import init_empty_weights |
| | from safetensors.torch import load_file |
| | from transformers import T5EncoderModel, T5Tokenizer |
| |
|
| | from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel |
| | from diffusers.utils.import_utils import is_accelerate_available |
| |
|
| |
|
# Context manager factory: accelerate's ``init_empty_weights`` when accelerate is installed, else a no-op.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

# Passed as ``model_max_length`` when the T5 tokenizer is loaded.
TOKENIZER_MAX_LENGTH = 256

parser = argparse.ArgumentParser(description="Convert original Mochi checkpoints into a diffusers MochiPipeline.")
parser.add_argument(
    "--transformer_checkpoint_path",
    default=None,
    type=str,
    help="Path to the original transformer safetensors checkpoint (skipped when omitted)",
)
parser.add_argument(
    "--vae_encoder_checkpoint_path",
    default=None,
    type=str,
    help="Path to the original VAE encoder safetensors checkpoint",
)
parser.add_argument(
    "--vae_decoder_checkpoint_path",
    default=None,
    type=str,
    help="Path to the original VAE decoder safetensors checkpoint",
)
parser.add_argument("--output_path", required=True, type=str, help="Directory the converted pipeline is saved to")
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
# Restricting the choices makes argparse reject unknown dtype names up front and lists the valid ones in --help.
parser.add_argument(
    "--dtype",
    type=str,
    default=None,
    choices=["fp16", "bf16", "fp32"],
    help="Dtype the converted transformer/VAE are cast to; no cast is applied when omitted",
)

args = parser.parse_args()
| |
|
| |
|
| | |
| | |
def swap_scale_shift(weight, dim):
    """Swap the two halves of ``weight`` along ``dim``.

    The tensor is split into a leading ``shift`` half and a trailing ``scale`` half, and the
    result stores ``scale`` first followed by ``shift``.

    Args:
        weight: Tensor whose size along ``dim`` can be split into two chunks.
        dim: Dimension along which the halves are split and re-concatenated.

    Returns:
        A new tensor with the same shape as ``weight`` and the two halves exchanged.
    """
    # Honour ``dim`` for both the split and the concatenation (it used to be ignored in favour of 0).
    shift, scale = weight.chunk(2, dim=dim)
    return torch.cat([scale, shift], dim=dim)
| |
|
| |
|
def swap_proj_gate(weight):
    """Return ``weight`` with its two halves along dimension 0 exchanged.

    The leading half (``proj``) ends up after the trailing half (``gate``).
    """
    leading_half, trailing_half = torch.chunk(weight, 2, dim=0)
    return torch.cat((trailing_half, leading_half), dim=0)
| |
|
| |
|
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
    """Load an original Mochi transformer checkpoint and rename its weights to the diffusers layout.

    Args:
        ckpt_path: Path of the safetensors file holding the original transformer weights.

    Returns:
        dict: Mapping from diffusers parameter names to the corresponding tensors.
    """
    source = load_file(ckpt_path, device="cpu")
    converted = {}

    def move(new_key, old_key):
        # Entries are popped so the final printout lists whatever was not converted.
        converted[new_key] = source.pop(old_key)

    def move_params(new_module, old_module):
        # Transfers a module's weight first, then its bias.
        for param in ("weight", "bias"):
            move(f"{new_module}.{param}", f"{old_module}.{param}")

    # Patch embedding, timestep embedding, caption pooler and caption projection.
    embedding_modules = (
        ("patch_embed.proj", "x_embedder.proj"),
        ("time_embed.timestep_embedder.linear_1", "t_embedder.mlp.0"),
        ("time_embed.timestep_embedder.linear_2", "t_embedder.mlp.2"),
        ("time_embed.pooler.to_kv", "t5_y_embedder.to_kv"),
        ("time_embed.pooler.to_q", "t5_y_embedder.to_q"),
        ("time_embed.pooler.to_out", "t5_y_embedder.to_out"),
        ("time_embed.caption_proj", "t5_yproj"),
    )
    for new_module, old_module in embedding_modules:
        move_params(new_module, old_module)

    # Transformer blocks.
    num_layers = 48
    for layer_idx in range(num_layers):
        new_block = f"transformer_blocks.{layer_idx}"
        old_block = f"blocks.{layer_idx}"
        # The final block names its context norm differently and gets no context output/feed-forward weights.
        is_final_block = layer_idx >= num_layers - 1

        # Modulation layers.
        move_params(f"{new_block}.norm1.linear", f"{old_block}.mod_x")
        context_norm = "norm1_context.linear_1" if is_final_block else "norm1_context.linear"
        move_params(f"{new_block}.{context_norm}", f"{old_block}.mod_y")

        # Attention over the "x" stream: the fused qkv matrix is split into three equal chunks.
        query, key, value = source.pop(f"{old_block}.attn.qkv_x.weight").chunk(3, dim=0)
        converted[f"{new_block}.attn1.to_q.weight"] = query
        converted[f"{new_block}.attn1.to_k.weight"] = key
        converted[f"{new_block}.attn1.to_v.weight"] = value
        move(f"{new_block}.attn1.norm_q.weight", f"{old_block}.attn.q_norm_x.weight")
        move(f"{new_block}.attn1.norm_k.weight", f"{old_block}.attn.k_norm_x.weight")
        move_params(f"{new_block}.attn1.to_out.0", f"{old_block}.attn.proj_x")

        # Attention over the "y" stream.
        query, key, value = source.pop(f"{old_block}.attn.qkv_y.weight").chunk(3, dim=0)
        converted[f"{new_block}.attn1.add_q_proj.weight"] = query
        converted[f"{new_block}.attn1.add_k_proj.weight"] = key
        converted[f"{new_block}.attn1.add_v_proj.weight"] = value
        move(f"{new_block}.attn1.norm_added_q.weight", f"{old_block}.attn.q_norm_y.weight")
        move(f"{new_block}.attn1.norm_added_k.weight", f"{old_block}.attn.k_norm_y.weight")
        if not is_final_block:
            move_params(f"{new_block}.attn1.to_add_out", f"{old_block}.attn.proj_y")

        # Feed-forward layers; the first projection has its two halves swapped.
        converted[f"{new_block}.ff.net.0.proj.weight"] = swap_proj_gate(source.pop(f"{old_block}.mlp_x.w1.weight"))
        move(f"{new_block}.ff.net.2.weight", f"{old_block}.mlp_x.w2.weight")
        if not is_final_block:
            converted[f"{new_block}.ff_context.net.0.proj.weight"] = swap_proj_gate(
                source.pop(f"{old_block}.mlp_y.w1.weight")
            )
            move(f"{new_block}.ff_context.net.2.weight", f"{old_block}.mlp_y.w2.weight")

    # Output layers; the final modulation stores its two halves in the opposite order.
    converted["norm_out.linear.weight"] = swap_scale_shift(source.pop("final_layer.mod.weight"), dim=0)
    converted["norm_out.linear.bias"] = swap_scale_shift(source.pop("final_layer.mod.bias"), dim=0)
    move_params("proj_out", "final_layer.linear")

    move("pos_frequencies", "pos_frequencies")

    print("Remaining Keys:", source.keys())

    return converted
| |
|
| |
|
def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path):
    """Load the original Mochi VAE encoder/decoder checkpoints and rename them to the diffusers layout.

    Args:
        encoder_ckpt_path: Path of the safetensors file holding the original encoder weights.
        decoder_ckpt_path: Path of the safetensors file holding the original decoder weights.

    Returns:
        dict: Single mapping with ``decoder.*`` and ``encoder.*`` diffusers parameter names.
    """
    encoder_weights = load_file(encoder_ckpt_path, device="cpu")
    decoder_weights = load_file(decoder_ckpt_path, device="cpu")
    converted = {}

    # (diffusers sub-module name, index of the matching entry in the original ``stack``)
    resnet_layout = (
        ("norm1.norm_layer", 0),
        ("conv1.conv", 2),
        ("norm2.norm_layer", 3),
        ("conv2.conv", 5),
    )

    def move_params(source, new_module, old_module):
        # Transfers a module's weight first, then its bias.
        for param in ("weight", "bias"):
            converted[f"{new_module}.{param}"] = source.pop(f"{old_module}.{param}")

    def move_resnet(source, new_resnet, old_resnet):
        # Transfers the two norm/conv pairs of one residual block.
        for module_name, stack_idx in resnet_layout:
            move_params(source, f"{new_resnet}.{module_name}", f"{old_resnet}.stack.{stack_idx}")

    def move_attention(source, new_attention, new_norm, old_layer):
        # Splits the fused qkv matrix and transfers the output projection and the attention norm.
        query, key, value = source.pop(f"{old_layer}.attn_block.attn.qkv.weight").chunk(3, dim=0)
        converted[f"{new_attention}.to_q.weight"] = query
        converted[f"{new_attention}.to_k.weight"] = key
        converted[f"{new_attention}.to_v.weight"] = value
        move_params(source, f"{new_attention}.to_out.0", f"{old_layer}.attn_block.attn.out")
        move_params(source, f"{new_norm}.norm_layer", f"{old_layer}.attn_block.norm")

    # ---- Decoder ----
    move_params(decoder_weights, "decoder.conv_in", "blocks.0.0")

    # Input block: the resnets follow the input conv, hence the +1 offset on the original index.
    for idx in range(3):
        move_resnet(decoder_weights, f"decoder.block_in.resnets.{idx}", f"blocks.0.{idx + 1}")

    # Up blocks: a run of resnets followed by one projection per block.
    for block_idx, depth in enumerate((6, 4, 3)):
        for idx in range(depth):
            move_resnet(
                decoder_weights,
                f"decoder.up_blocks.{block_idx}.resnets.{idx}",
                f"blocks.{block_idx + 1}.blocks.{idx}",
            )
        move_params(decoder_weights, f"decoder.up_blocks.{block_idx}.proj", f"blocks.{block_idx + 1}.proj")

    # Output block and final projection.
    for idx in range(3):
        move_resnet(decoder_weights, f"decoder.block_out.resnets.{idx}", f"blocks.4.{idx}")
    move_params(decoder_weights, "decoder.proj_out", "output_proj")

    print("Remaining Decoder Keys:", decoder_weights.keys())

    # ---- Encoder ----
    move_params(encoder_weights, "encoder.proj_in", "layers.0")

    # Input block.
    for idx in range(3):
        move_resnet(encoder_weights, f"encoder.block_in.resnets.{idx}", f"layers.{idx + 1}")

    # Down blocks: an input conv, then resnet + attention pairs.
    for block_idx, depth in enumerate((3, 4, 6)):
        new_block = f"encoder.down_blocks.{block_idx}"
        old_block = f"layers.{block_idx + 4}"
        move_params(encoder_weights, f"{new_block}.conv_in.conv", f"{old_block}.layers.0")

        for idx in range(depth):
            old_layer = f"{old_block}.layers.{idx + 1}"
            move_resnet(encoder_weights, f"{new_block}.resnets.{idx}", old_layer)
            move_attention(encoder_weights, f"{new_block}.attentions.{idx}", f"{new_block}.norms.{idx}", old_layer)

    # Output block: resnet + attention pairs.
    for idx in range(3):
        old_layer = f"layers.{idx + 7}"
        move_resnet(encoder_weights, f"encoder.block_out.resnets.{idx}", old_layer)
        move_attention(
            encoder_weights,
            f"encoder.block_out.attentions.{idx}",
            f"encoder.block_out.norms.{idx}",
            old_layer,
        )

    # Output norm and projection; only the projection's weight is transferred.
    move_params(encoder_weights, "encoder.norm_out.norm_layer", "output_norm")
    converted["encoder.proj_out.weight"] = encoder_weights.pop("output_proj.weight")

    print("Remaining Encoder Keys:", encoder_weights.keys())

    return converted
| |
|
| |
|
def main(args):
    """Convert the requested Mochi checkpoints and save them as a diffusers ``MochiPipeline``.

    Args:
        args: Parsed command-line namespace. Reads ``dtype``, ``transformer_checkpoint_path``,
            ``vae_encoder_checkpoint_path``, ``vae_decoder_checkpoint_path``,
            ``text_encoder_cache_dir``, ``output_path`` and ``push_to_hub``.

    Raises:
        ValueError: If ``args.dtype`` is set to anything other than ``"fp16"``, ``"bf16"`` or ``"fp32"``.
    """
    # ``None`` (the CLI default) means "apply no cast". It must be part of the same chain as the
    # named dtypes, otherwise it falls through to the ``else`` branch and is rejected.
    if args.dtype is None:
        dtype = None
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    # Components without a checkpoint path stay ``None`` and are passed to the pipeline as such.
    transformer = None
    vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = MochiTransformer3DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            transformer = transformer.to(dtype=dtype)

    # NOTE(review): the VAE is converted only when BOTH paths are given; a single path is ignored.
    if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None:
        vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
        converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers(
            args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path
        )
        vae.load_state_dict(converted_vae_state_dict, strict=True)
        if dtype is not None:
            vae = vae.to(dtype=dtype)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Re-lay out every text-encoder parameter contiguously before saving
    # (presumably required for safetensors serialization -- TODO confirm).
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    pipe = MochiPipeline(
        scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True),
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
    )
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
| |
|
| |
|
# Script entry point: run the conversion with the module-level parsed arguments when executed directly.
if __name__ == "__main__":
    main(args)
| |
|