| | import argparse |
| | from typing import Any, Dict |
| |
|
| | import torch |
| | from transformers import T5EncoderModel, T5Tokenizer |
| |
|
| | from diffusers import ( |
| | AutoencoderKLCogVideoX, |
| | CogVideoXDDIMScheduler, |
| | CogVideoXImageToVideoPipeline, |
| | CogVideoXPipeline, |
| | CogVideoXTransformer3DModel, |
| | ) |
| |
|
| |
|
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused ``query_key_value`` entry into separate q/k/v entries.

    The tensor stored under ``key`` is cut into three chunks along dim 0.
    The chunks are stored, in order, under ``key`` with the substring
    ``query_key_value`` replaced by ``to_q``, ``to_k`` and ``to_v``; the
    fused entry is then removed. ``state_dict`` is modified in place.
    """
    # Unpacking into exactly three names fails loudly (before any mutation)
    # if the tensor cannot be split three ways.
    query, key_proj, value = torch.chunk(state_dict[key], chunks=3, dim=0)

    state_dict[key.replace("query_key_value", "to_q")] = query
    state_dict[key.replace("query_key_value", "to_k")] = key_proj
    state_dict[key.replace("query_key_value", "to_v")] = value
    del state_dict[key]
| |
|
| |
|
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
    """Move a query/key layernorm entry to its per-block attention-norm key.

    The last two dot-separated components of ``key`` are read as the layer
    index and the parameter name. The entry is re-keyed to
    ``transformer_blocks.<layer>.attn1.norm_q.<param>`` when ``key`` contains
    "query", or ``...norm_k.<param>`` when it contains "key".
    ``state_dict`` is modified in place.

    Raises:
        ValueError: if ``key`` contains neither "query" nor "key" (or does
            not have at least two dot-separated components). Nothing is
            modified in that case.
    """
    layer_id, weight_or_bias = key.split(".")[-2:]

    if "query" in key:
        norm_name = "norm_q"
    elif "key" in key:
        norm_name = "norm_k"
    else:
        # Previously this fell through and crashed on an unbound local.
        raise ValueError(f"Expected a query or key layernorm entry, got: {key!r}")

    new_key = f"transformer_blocks.{layer_id}.attn1.{norm_name}.{weight_or_bias}"
    state_dict[new_key] = state_dict.pop(key)
| |
|
| |
|
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused adaLN modulation entry into ``norm1`` and ``norm2`` entries.

    The last three dot-separated components of ``key`` are read as
    ``<layer>.<ignored>.<param>``. The tensor under ``key`` is cut into 12
    chunks along dim 0: chunks 0-2 and 6-8 are concatenated into
    ``transformer_blocks.<layer>.norm1.linear.<param>``, and chunks 3-5 and
    9-11 into ``transformer_blocks.<layer>.norm2.linear.<param>``. The fused
    entry is removed. ``state_dict`` is modified in place.
    """
    layer_id, _, weight_or_bias = key.split(".")[-3:]
    block_prefix = f"transformer_blocks.{layer_id}"

    pieces = state_dict[key].chunk(12, dim=0)
    # Build both tensors before touching the dict so a failure leaves it intact.
    norm1_tensor = torch.cat(pieces[0:3] + pieces[6:9])
    norm2_tensor = torch.cat(pieces[3:6] + pieces[9:12])

    state_dict[f"{block_prefix}.norm1.linear.{weight_or_bias}"] = norm1_tensor
    state_dict[f"{block_prefix}.norm2.linear.{weight_or_bias}"] = norm2_tensor
    state_dict.pop(key)
| |
|
| |
|
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Drop the entry stored under ``key`` from ``state_dict`` (in place).

    Raises ``KeyError`` if ``key`` is not present.
    """
    del state_dict[key]
| |
|
| |
|
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Re-key an ``up`` entry to ``up_blocks`` with a mirrored block index.

    ``key`` is split on dots; component 1 is replaced by ``up_blocks`` and
    component 2 (an integer index ``i``) by ``4 - 1 - i``, i.e. 0<->3 and
    1<->2. The entry is moved to the resulting key; ``state_dict`` is
    modified in place.
    """
    parts = key.split(".")
    mirrored_index = 4 - 1 - int(parts[2])
    parts[1:3] = ["up_blocks", str(mirrored_index)]

    state_dict[".".join(parts)] = state_dict.pop(key)
| |
|
| |
|
# Substring renames applied to every transformer checkpoint key.
#
# NOTE: insertion order matters. convert_transformer applies these one after
# another with str.replace, so each pattern sees the result of the previous
# ones:
#   * "transformer.final_layernorm" must precede the generic "transformer";
#   * "attention" -> "attn1" runs before "post_attn1_layernorm", which is why
#     that pattern is written with "attn1" rather than "attention";
#   * "dense_h_to_4h" / "dense_4h_to_h" must precede the generic "dense".
TRANSFORMER_KEYS_RENAME_DICT = {
    # Transformer blocks
    "transformer.final_layernorm": "norm_final",
    "transformer": "transformer_blocks",
    "attention": "attn1",
    "mlp": "ff.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "input_layernorm": "norm1.norm",
    "post_attn1_layernorm": "norm2.norm",
    # Embeddings
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "ofs_embed.0": "ofs_embedding.linear_1",
    "ofs_embed.2": "ofs_embedding.linear_2",
    # Mixins
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",
    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",
}
| |
|
# Handlers for transformer keys that need more than a substring rename.
# convert_transformer runs these after the renames: for every key containing
# the pattern on the left, the handler on the right is called as
# handler(key, state_dict) and rewrites or removes the entry in place.
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    # Fused tensors that are split into several entries
    "query_key_value": reassign_query_key_value_inplace,
    "query_layernorm_list": reassign_query_key_layernorm_inplace,
    "key_layernorm_list": reassign_query_key_layernorm_inplace,
    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
    # Entries that are dropped from the converted state dict
    "embed_tokens": remove_keys_inplace,
    "freqs_sin": remove_keys_inplace,
    "freqs_cos": remove_keys_inplace,
    "position_embedding": remove_keys_inplace,
}
| |
|
# Substring renames applied to every VAE checkpoint key.
# convert_vae applies them sequentially with str.replace, in insertion order.
VAE_KEYS_RENAME_DICT = {
    # Generic block / sampler names
    "block.": "resnets.",
    "down.": "down_blocks.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "nin_shortcut": "conv_shortcut",
    # Mid blocks of the encoder and decoder
    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
| |
|
# Handlers for VAE keys that need more than a substring rename. convert_vae
# calls handler(key, state_dict) for every renamed key containing the pattern.
VAE_SPECIAL_KEYS_REMAP = {
    "loss": remove_keys_inplace,  # dropped from the converted state dict
    "up.": replace_up_keys_inplace,  # re-keyed to up_blocks with mirrored index
}

# Passed as ``model_max_length`` when the T5 tokenizer is loaded.
TOKENIZER_MAX_LENGTH = 226
| |
|
| |
|
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return the innermost state dict of a loaded checkpoint.

    Checkpoints may wrap the weights under ``"model"``, ``"module"`` and/or
    ``"state_dict"``. Each wrapper is peeled in that order, and every check
    is made against the *current* (already unwrapped) level. The original
    checked the top-level dict each time, so nested wrappers were not peeled
    and sibling top-level wrapper keys caused a ``KeyError``.

    Args:
        saved_dict: object returned by loading the checkpoint.

    Returns:
        The unwrapped state dict, or ``saved_dict`` itself when no wrapper
        key is present.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Only descend through real dict levels; stop probing once a
        # non-dict payload is reached.
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
| |
|
| |
|
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Move the entry stored under ``old_key`` to ``new_key`` in place.

    Returns ``None``; the return annotation previously claimed a dict even
    though nothing is returned. Raises ``KeyError`` if ``old_key`` is
    missing.
    """
    state_dict[new_key] = state_dict.pop(old_key)
| |
|
| |
|
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
    init_kwargs: Dict[str, Any],
):
    """Load an original transformer checkpoint into a ``CogVideoXTransformer3DModel``.

    Steps:
      1. Load the checkpoint on CPU and unwrap it with ``get_state_dict``.
      2. Build the model from the given arguments plus ``init_kwargs`` and
         cast it to ``dtype``.
      3. Drop the leading prefix of every key and apply
         ``TRANSFORMER_KEYS_RENAME_DICT`` in order.
      4. Run the ``TRANSFORMER_SPECIAL_KEYS_REMAP`` handlers on matching keys.
      5. Load the converted state dict with ``strict=True``.

    Args:
        ckpt_path: path to the original checkpoint file.
        num_layers: forwarded to the model constructor.
        num_attention_heads: forwarded to the model constructor.
        use_rotary_positional_embeddings: forwarded to the model constructor.
        i2v: selects 32 input channels (otherwise 16) and, together with
            ``init_kwargs["patch_size_t"]``, the offset-embedding and
            learned-positional-embedding settings.
        dtype: dtype the model is cast to before the weights are loaded.
        init_kwargs: extra constructor kwargs; ``"patch_size_t"`` is read
            when ``i2v`` is truthy.

    Returns:
        The model with the converted weights loaded.
    """
    PREFIX_KEY = "model.diffusion_model."
    prefix_length = len(PREFIX_KEY)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's weights_only default; only convert trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu", mmap=True)
    original_state_dict = get_state_dict(checkpoint)

    transformer = CogVideoXTransformer3DModel(
        in_channels=32 if i2v else 16,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,
        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,
        **init_kwargs,
    )
    transformer = transformer.to(dtype=dtype)

    # Pass 1: plain substring renames. Iterate over a snapshot of the keys
    # because the dict is rewritten while looping.
    # NOTE(review): the first len(PREFIX_KEY) characters are sliced off
    # unconditionally, so this assumes every key starts with PREFIX_KEY.
    for old_key in list(original_state_dict.keys()):
        renamed_key = old_key[prefix_length:]
        for pattern, replacement in TRANSFORMER_KEYS_RENAME_DICT.items():
            renamed_key = renamed_key.replace(pattern, replacement)
        original_state_dict[renamed_key] = original_state_dict.pop(old_key)

    # Pass 2: entries that must be split, re-keyed structurally, or dropped.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer
| |
|
| |
|
def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
    """Load an original VAE checkpoint into an ``AutoencoderKLCogVideoX``.

    The model is built with ``scaling_factor`` (plus
    ``invert_scale_latents=True`` when ``version == "1.5"``) and cast to
    ``dtype``. Checkpoint keys are renamed with ``VAE_KEYS_RENAME_DICT``,
    post-processed by the ``VAE_SPECIAL_KEYS_REMAP`` handlers, and loaded
    with ``strict=True``.

    Returns:
        The VAE with the converted weights loaded.
    """
    init_kwargs = {"scaling_factor": scaling_factor}
    if version == "1.5":
        init_kwargs["invert_scale_latents"] = True

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's weights_only default; only convert trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu", mmap=True)
    original_state_dict = get_state_dict(checkpoint)

    vae = AutoencoderKLCogVideoX(**init_kwargs)
    vae = vae.to(dtype=dtype)

    # Pass 1: plain substring renames (snapshot the keys; the dict is
    # rewritten while looping).
    for old_key in list(original_state_dict.keys()):
        renamed_key = old_key
        for pattern, replacement in VAE_KEYS_RENAME_DICT.items():
            renamed_key = renamed_key.replace(pattern, replacement)
        original_state_dict[renamed_key] = original_state_dict.pop(old_key)

    # Pass 2: entries that must be re-keyed structurally or dropped.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae
| |
|
| |
|
def get_transformer_init_kwargs(version: str) -> Dict[str, Any]:
    """Return the default transformer constructor kwargs for a CogVideoX version.

    Args:
        version: ``"1.0"`` or ``"1.5"``.

    Returns:
        A fresh dict with ``patch_size``, ``patch_size_t``, ``patch_bias``,
        ``sample_height``, ``sample_width`` and ``sample_frames``.

    Raises:
        ValueError: for any other ``version``.
    """
    if version == "1.0":
        # Sample size is derived from a 480x720 resolution divided by the
        # spatial scale factor.
        vae_scale_factor_spatial = 8
        return {
            "patch_size": 2,
            "patch_size_t": None,
            "patch_bias": True,
            "sample_height": 480 // vae_scale_factor_spatial,
            "sample_width": 720 // vae_scale_factor_spatial,
            "sample_frames": 49,
        }

    if version == "1.5":
        # Sample size is given directly here, so no scale factor is needed
        # (the unused local that used to be assigned in this branch is gone).
        return {
            "patch_size": 2,
            "patch_size_t": 2,
            "patch_bias": False,
            "sample_height": 300,
            "sample_width": 300,
            "sample_frames": 81,
        }

    raise ValueError("Unsupported version of CogVideoX.")
| |
|
| |
|
def get_args():
    """Parse the command-line arguments of the conversion script.

    Only ``--output_path`` is required. The checkpoint paths default to
    ``None`` and the remaining options have the defaults declared below.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = argparse.ArgumentParser()

    # Input checkpoints and output location
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")

    # Precision and publishing
    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
    parser.add_argument(
        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
    )

    # Text encoder
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    parser.add_argument(
        "--typecast_text_encoder",
        action="store_true",
        default=False,
        help="Whether or not to apply fp16/bf16 precision to text_encoder",
    )

    # Transformer architecture
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
    parser.add_argument(
        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
    )

    # VAE and scheduler
    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
    # The help text used to be a copy of the --scaling_factor help; this value
    # is passed to the scheduler config, not to the VAE.
    parser.add_argument(
        "--snr_shift_scale", type=float, default=3.0, help="SNR shift scale passed to the scheduler config"
    )

    # Model variant
    parser.add_argument(
        "--i2v",
        action="store_true",
        default=False,
        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
    )
    parser.add_argument(
        "--version",
        choices=["1.0", "1.5"],
        default="1.0",
        help="Which version of CogVideoX to use for initializing default modeling parameters.",
    )
    return parser.parse_args()
| |
|
| |
|
if __name__ == "__main__":
    args = get_args()

    if args.fp16 and args.bf16:
        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

    # Precision for the transformer (and optionally the text encoder).
    if args.fp16:
        dtype = torch.float16
    elif args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    # Either component stays None when its checkpoint path is not given; the
    # pipeline is still assembled and saved with whatever was converted.
    transformer = None
    if args.transformer_ckpt_path is not None:
        init_kwargs = get_transformer_init_kwargs(args.version)
        transformer = convert_transformer(
            args.transformer_ckpt_path,
            args.num_layers,
            args.num_attention_heads,
            args.use_rotary_positional_embeddings,
            args.i2v,
            dtype,
            init_kwargs,
        )

    vae = None
    if args.vae_ckpt_path is not None:
        # The VAE is always converted in float32, regardless of --fp16/--bf16.
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    if args.typecast_text_encoder:
        text_encoder = text_encoder.to(dtype=dtype)

    # Make every text-encoder parameter contiguous before saving.
    # NOTE(review): presumably required by the safetensors serialization
    # below; confirm before removing.
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    scheduler_config = {
        "snr_shift_scale": args.snr_shift_scale,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "clip_sample": False,
        "num_train_timesteps": 1000,
        "prediction_type": "v_prediction",
        "rescale_betas_zero_snr": True,
        "set_alpha_to_one": True,
        "timestep_spacing": "trailing",
    }
    scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)

    pipeline_cls = CogVideoXImageToVideoPipeline if args.i2v else CogVideoXPipeline
    pipe = pipeline_cls(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Save as sharded safetensors (5GB per shard) and optionally push to the Hub.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
| |
|