| | import argparse |
| | from typing import Any, Dict |
| |
|
| | import torch |
| | from huggingface_hub import hf_hub_download |
| | from safetensors.torch import load_file |
| |
|
| | from diffusers import AutoencoderDC |
| |
|
| |
|
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    """Split a fused QKV weight into separate ``to_q``/``to_k``/``to_v`` entries, in place.

    The entry stored under ``key`` is removed from ``state_dict``, cut into three
    chunks along dim 0, and each chunk (with all size-1 dims squeezed away) is
    stored under ``<parent>.to_q.weight``, ``<parent>.to_k.weight`` and
    ``<parent>.to_v.weight``, where ``<parent>`` is everything in ``key`` before
    the last ``.qkv.conv.weight``.

    Args:
        key: Name of the fused weight entry in ``state_dict``.
        state_dict: Mapping of parameter names to tensors; mutated in place.
    """
    # Text preceding the final ".qkv.conv.weight" names the owning module.
    parent_module = key.rpartition(".qkv.conv.weight")[0]

    fused_weight = state_dict.pop(key)
    # Explicit three-way unpacking: fails loudly if fewer than 3 chunks come back.
    query_w, key_w, value_w = torch.chunk(fused_weight, 3, dim=0)

    state_dict[f"{parent_module}.to_q.weight"] = query_w.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = key_w.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = value_w.squeeze()
| |
|
| |
|
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    """Move a ``*.proj.conv.weight`` entry to ``*.to_out.weight``, in place.

    The tensor stored under ``key`` is removed from ``state_dict`` and re-inserted,
    with all size-1 dims squeezed away, under ``<parent>.to_out.weight``, where
    ``<parent>`` is everything in ``key`` before the last ``.proj.conv.weight``.

    Args:
        key: Name of the projection weight entry in ``state_dict``.
        state_dict: Mapping of parameter names to tensors; mutated in place.
    """
    parent_module = key.rpartition(".proj.conv.weight")[0]
    projection_weight = state_dict.pop(key)
    state_dict[f"{parent_module}.to_out.weight"] = projection_weight.squeeze()
| |
|
| |
|
# Substring replacements that convert_ae applies to every checkpoint key.
# They are applied one after another in insertion order via str.replace, so the
# order of the entries below is significant and must be preserved.
AE_KEYS_RENAME_DICT = {
    # Substrings removed or renamed wherever they occur in a key.
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # Sub-module renames.
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # Encoder-specific renames.
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # Decoder-specific renames.
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}
| |
|
# Per-family overrides that get_ae_config merges into AE_KEYS_RENAME_DICT for the
# "in"/"mix" checkpoints. The three tables currently hold the same entries but
# are kept as separate objects, one per checkpoint family.

# Overrides for the dc-ae-f32c32 "in"/"mix" checkpoints.
AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

# Overrides for the dc-ae-f64c128 "in"/"mix" checkpoints.
AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

# Overrides for the dc-ae-f128c512 "in"/"mix" checkpoints.
AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}
| |
|
# Keys that need more than a rename: convert_ae calls the mapped handler, which
# rewrites the state dict in place, for every key containing the given substring.
AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}
| |
|
| |
|
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap a loaded checkpoint down to its actual state dict.

    Checkpoints may nest the weights under ``"model"``, ``"module"`` and/or
    ``"state_dict"``. Each wrapper is peeled off in that order, and every check
    is made against the *current* (already unwrapped) level, so nested wrappers
    such as ``{"model": {"state_dict": {...}}}`` are handled and a stray
    top-level key can no longer trigger a ``KeyError`` on the inner dict.

    Args:
        saved_dict: The object loaded from the checkpoint file.

    Returns:
        The innermost object reached after removing the wrappers; ``saved_dict``
        itself when none of the wrapper keys is present.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Only mapping-like levels can be unwrapped further; anything else
        # (e.g. a tensor stored directly under a wrapper key) is returned as is.
        if hasattr(state_dict, "keys") and wrapper_key in state_dict.keys():
            state_dict = state_dict[wrapper_key]
    return state_dict
| |
|
| |
|
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Rename one entry of ``state_dict`` in place.

    The value stored under ``old_key`` is removed and re-inserted under
    ``new_key``. When both names are equal the value is kept, but the entry
    moves to the end of the dict's insertion order.

    Args:
        state_dict: Mapping to mutate.
        old_key: Existing key to remove.
        new_key: Key under which the value is stored afterwards.

    Raises:
        KeyError: If ``old_key`` is not present in ``state_dict``.
    """
    state_dict[new_key] = state_dict.pop(old_key)
| |
|
| |
|
def convert_ae(config_name: str, dtype: torch.dtype):
    """Download a DC-AE checkpoint and load it into a diffusers ``AutoencoderDC``.

    Steps:
      1. Resolve the model config for ``config_name`` (this may also extend
         ``AE_KEYS_RENAME_DICT`` with family-specific overrides).
      2. Fetch ``model.safetensors`` from the ``mit-han-lab/<config_name>`` hub repo.
      3. Rename every checkpoint key using ``AE_KEYS_RENAME_DICT``.
      4. Run the in-place handlers from ``AE_SPECIAL_KEYS_REMAP`` on matching keys.
      5. Load the result into the model with ``strict=True``.

    Args:
        config_name: Checkpoint name understood by ``get_ae_config``.
        dtype: Torch dtype the model is cast to before the weights are loaded.

    Returns:
        The ``AutoencoderDC`` instance with the converted weights loaded.
    """
    # Resolved first so an unknown name fails before anything is downloaded.
    model_kwargs = get_ae_config(config_name)

    checkpoint_file = hf_hub_download(f"mit-han-lab/{config_name}", "model.safetensors")
    weights = get_state_dict(load_file(checkpoint_file))

    model = AutoencoderDC(**model_kwargs).to(dtype=dtype)

    # Pass 1: plain substring renames, applied in the table's insertion order.
    # Iterate over a snapshot because the dict is rewritten while looping.
    for source_name in tuple(weights.keys()):
        target_name = source_name
        for old_fragment, new_fragment in AE_KEYS_RENAME_DICT.items():
            target_name = target_name.replace(old_fragment, new_fragment)
        weights[target_name] = weights.pop(source_name)

    # Pass 2: entries that must be split or reshaped, not just renamed.
    for param_name in tuple(weights.keys()):
        for marker, remap_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if marker in param_name:
                remap_inplace(param_name, weights)

    model.load_state_dict(weights, strict=True)
    return model
| |
|
| |
|
def get_ae_config(name: str):
    """Return the ``AutoencoderDC`` keyword arguments for a DC-AE checkpoint name.

    Side effect: for the "in"/"mix" checkpoints, the matching family override
    table (``AE_F32C32_KEYS``, ``AE_F64C128_KEYS`` or ``AE_F128C512_KEYS``) is
    merged into the module-level ``AE_KEYS_RENAME_DICT``.
    NOTE(review): that mutation persists for the rest of the process, so
    converting a "sana" checkpoint after an "in"/"mix" one in the same process
    would still see the overridden rename entries.

    Args:
        name: Checkpoint name, e.g. ``"dc-ae-f32c32-sana-1.0"``.

    Returns:
        A dict of constructor keyword arguments, including ``scaling_factor``.

    Raises:
        ValueError: If ``name`` is not one of the supported checkpoint names.
    """
    res_block, vit_block = "ResBlock", "EfficientViTBlock"

    if name == "dc-ae-f32c32-sana-1.0":
        # The sana config uses tuples (except decoder_layers_per_block).
        sana_block_types = (res_block,) * 3 + (vit_block,) * 3
        sana_channels = (128, 256, 512, 512, 1024, 1024)
        sana_multiscales = ((),) * 3 + ((5,),) * 3
        return {
            "latent_channels": 32,
            "encoder_block_types": sana_block_types,
            "decoder_block_types": sana_block_types,
            "encoder_block_out_channels": sana_channels,
            "decoder_block_out_channels": sana_channels,
            "encoder_qkv_multiscales": sana_multiscales,
            "decoder_qkv_multiscales": sana_multiscales,
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }

    # The "in"/"mix" families share one layout: three ResBlock stages followed
    # by `num_vit` EfficientViTBlock stages. They differ only in that count,
    # the latent width, the rename overrides and the scaling factor.
    if name in ("dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"):
        key_overrides, latent_channels, num_vit = AE_F32C32_KEYS, 32, 3
        scaling_factor = 0.3189 if name == "dc-ae-f32c32-in-1.0" else 0.4552
    elif name in ("dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"):
        key_overrides, latent_channels, num_vit = AE_F64C128_KEYS, 128, 4
        scaling_factor = 0.2889 if name == "dc-ae-f64c128-in-1.0" else 0.4538
    elif name in ("dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"):
        key_overrides, latent_channels, num_vit = AE_F128C512_KEYS, 512, 5
        scaling_factor = 0.4883 if name == "dc-ae-f128c512-in-1.0" else 0.3620
    else:
        raise ValueError("Invalid config name provided.")

    AE_KEYS_RENAME_DICT.update(key_overrides)

    num_stages = 3 + num_vit
    # Channel widths per stage; each family uses the first `num_stages` entries.
    channel_schedule = [128, 256, 512, 512, 1024, 1024, 2048, 2048]

    # Every list value is built separately so no two config entries alias.
    return {
        "latent_channels": latent_channels,
        "encoder_block_types": [res_block] * 3 + [vit_block] * num_vit,
        "decoder_block_types": [res_block] * 3 + [vit_block] * num_vit,
        "encoder_block_out_channels": channel_schedule[:num_stages],
        "decoder_block_out_channels": channel_schedule[:num_stages],
        "encoder_layers_per_block": [0, 4, 8] + [2] * num_vit,
        "decoder_layers_per_block": [0, 5, 10] + [2] * num_vit,
        "encoder_qkv_multiscales": ((),) * num_stages,
        "decoder_qkv_multiscales": ((),) * num_stages,
        "decoder_norm_types": ["batch_norm"] * 3 + ["rms_norm"] * num_vit,
        "decoder_act_fns": ["relu"] * 3 + ["silu"] * num_vit,
        "scaling_factor": scaling_factor,
    }
| |
|
| |
|
def get_args():
    """Parse the command-line arguments of the conversion script.

    Options:
        --config_name: DC-AE checkpoint to convert (default
            ``dc-ae-f32c32-sana-1.0``); restricted to the supported names.
        --output_path: Required. Where the converted model is saved.
        --dtype: Precision to save in; one of ``fp32`` (default), ``fp16``,
            ``bf16``. Unsupported values are rejected by argparse with a usage
            message instead of failing later with a ``KeyError`` at the dtype
            lookup.

    Returns:
        The ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument(
        "--dtype",
        default="fp32",
        # Keep in sync with the keys of DTYPE_MAPPING / VARIANT_MAPPING.
        choices=["fp32", "fp16", "bf16"],
        help="Torch dtype to save the model in.",
    )
    return parser.parse_args()
| |
|
| |
|
# Maps the --dtype command-line value to the torch dtype the model is cast to.
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

# Maps the --dtype command-line value to the `variant` argument passed to
# save_pretrained; fp32 is saved with no variant.
VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}
| |
|
| |
|
# Script entry point: parse the CLI options, convert the requested checkpoint
# and write the result to --output_path as safetensors.
if __name__ == "__main__":
    args = get_args()

    # Both lookups happen before the (slow) conversion so a bad dtype fails early.
    dtype, variant = DTYPE_MAPPING[args.dtype], VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(
        args.output_path,
        safe_serialization=True,
        max_shard_size="5GB",
        variant=variant,
    )
| |
|