# Conversion script: load an original Asymmetric VQGAN checkpoint and export it
# as a diffusers `AsymmetricAutoencoderKL` pretrained model directory.
import argparse
import time
from pathlib import Path
from typing import Any, Dict, Literal

import torch

from diffusers import AsymmetricAutoencoderKL
# Constructor kwargs for the x1.5 Asymmetric VQGAN variant.
# The decoder (up blocks) is wider/deeper than the encoder (down blocks) —
# that asymmetry is the point of this architecture.
ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "down_block_out_channels": [128, 256, 512, 512],
    "layers_per_down_block": 2,
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    # Decoder channels are 1.5x the encoder's (hence the "x1.5" name).
    "up_block_out_channels": [192, 384, 768, 768],
    "layers_per_up_block": 3,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 32,
    "sample_size": 256,
    # Standard SD-VAE latent scaling factor.
    "scaling_factor": 0.18215,
}
# Constructor kwargs for the x2 Asymmetric VQGAN variant.
# Same encoder as the x1.5 config; decoder channels doubled and deeper
# (5 layers per up block instead of 3).
ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "down_block_out_channels": [128, 256, 512, 512],
    "layers_per_down_block": 2,
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    # Decoder channels are 2x the encoder's (hence the "x2" name).
    "up_block_out_channels": [256, 512, 1024, 1024],
    "layers_per_up_block": 5,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 32,
    "sample_size": 256,
    # Standard SD-VAE latent scaling factor.
    "scaling_factor": 0.18215,
}
def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Rename keys of an original Asymmetric VQGAN state dict to diffusers' layout.

    Keys are mapped by chained string substitutions (e.g. ``encoder.down.`` ->
    ``encoder.down_blocks.``, ``.attn_1.q.`` -> ``.attentions.0.to_q.``).
    ``quant_conv.*`` / ``post_quant_conv.*`` keys pass through unchanged; any
    other unrecognized key is skipped with a console message.

    NOTE: the order of the ``.replace`` calls matters — e.g. ``.block.`` must be
    rewritten before ``.block_1.`` / ``.block_2.`` are matched, and the decoder's
    ``.up.N.`` indices are deliberately reversed (``up.0`` -> ``up_blocks.3``).

    Args:
        original_state_dict: mapping of original checkpoint keys to tensors.

    Returns:
        A new dict with diffusers-style keys; attention q/k/v/out projection
        weights are squeezed from 1x1-conv shape ``(out, in, 1, 1)`` to the
        linear shape ``(out, in)`` that diffusers expects.
    """
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        if k.startswith("encoder."):
            converted_state_dict[
                k.replace("encoder.down.", "encoder.down_blocks.")
                .replace("encoder.mid.", "encoder.mid_block.")
                .replace("encoder.norm_out.", "encoder.conv_norm_out.")
                .replace(".downsample.", ".downsamplers.0.")
                .replace(".nin_shortcut.", ".conv_shortcut.")
                .replace(".block.", ".resnets.")
                .replace(".block_1.", ".resnets.0.")
                .replace(".block_2.", ".resnets.1.")
                .replace(".attn_1.k.", ".attentions.0.to_k.")
                .replace(".attn_1.q.", ".attentions.0.to_q.")
                .replace(".attn_1.v.", ".attentions.0.to_v.")
                .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
                .replace(".attn_1.norm.", ".attentions.0.group_norm.")
            ] = v
        elif k.startswith("decoder.") and "up_layers" not in k:
            converted_state_dict[
                k.replace("decoder.encoder.", "decoder.condition_encoder.")
                .replace(".norm_out.", ".conv_norm_out.")
                # Original up blocks are ordered low-res -> high-res; diffusers
                # orders up_blocks high-res -> low-res, so indices are reversed.
                .replace(".up.0.", ".up_blocks.3.")
                .replace(".up.1.", ".up_blocks.2.")
                .replace(".up.2.", ".up_blocks.1.")
                .replace(".up.3.", ".up_blocks.0.")
                .replace(".block.", ".resnets.")
                .replace("mid", "mid_block")
                .replace(".0.upsample.", ".0.upsamplers.0.")
                .replace(".1.upsample.", ".1.upsamplers.0.")
                .replace(".2.upsample.", ".2.upsamplers.0.")
                .replace(".nin_shortcut.", ".conv_shortcut.")
                .replace(".block_1.", ".resnets.0.")
                .replace(".block_2.", ".resnets.1.")
                .replace(".attn_1.k.", ".attentions.0.to_k.")
                .replace(".attn_1.q.", ".attentions.0.to_q.")
                .replace(".attn_1.v.", ".attentions.0.to_v.")
                .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
                .replace(".attn_1.norm.", ".attentions.0.group_norm.")
            ] = v
        elif k.startswith("quant_conv."):
            converted_state_dict[k] = v
        elif k.startswith("post_quant_conv."):
            converted_state_dict[k] = v
        else:
            print(f" skipping key `{k}`")
    # Fix weight shapes: the original stores attention q/k/v/out projections as
    # 1x1 convolutions; diffusers uses nn.Linear, so drop the trailing 1x1
    # spatial dimensions.
    for k, v in converted_state_dict.items():
        if (
            (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
            and k.endswith("weight")
            and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
        ):
            converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
    return converted_state_dict
def get_asymmetric_autoencoder_kl_from_original_checkpoint(
    scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
) -> AsymmetricAutoencoderKL:
    """Build an ``AsymmetricAutoencoderKL`` from an original Asymmetric VQGAN checkpoint.

    Args:
        scale: which architecture variant the checkpoint is, ``"1.5"`` or ``"2"``;
            selects the matching config dict (any value other than ``"1.5"``
            falls through to the x2 config — callers validate upstream).
        original_checkpoint_path: path to the original ``torch.save`` checkpoint;
            expected to contain its weights under the ``"state_dict"`` key.
        map_location: device to load the checkpoint tensors onto.

    Returns:
        The model in eval mode with the converted weights loaded (strictly —
        ``load_state_dict`` raises if any key is missing or unexpected).
    """
    print("Loading original state_dict")
    # NOTE: torch.load unpickles arbitrary objects — only run this on
    # checkpoints from a trusted source.
    original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
    original_state_dict = original_state_dict["state_dict"]
    print("Converting state_dict")
    converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
    kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
    print("Initializing AsymmetricAutoencoderKL model")
    asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
    print("Loading weight from converted state_dict")
    asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
    asymmetric_autoencoder_kl.eval()
    print("AsymmetricAutoencoderKL successfully initialized")
    return asymmetric_autoencoder_kl
if __name__ == "__main__":
    start = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scale",
        default=None,
        type=str,
        required=True,
        help="Asymmetric VQGAN scale: `1.5` or `2`",
    )
    parser.add_argument(
        "--original_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the original Asymmetric VQGAN checkpoint",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        required=True,
        help="Path to save pretrained AsymmetricAutoencoderKL model",
    )
    parser.add_argument(
        "--map_location",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading the checkpoint",
    )
    args = parser.parse_args()

    # Validate inputs up front so we fail before the (slow) checkpoint load.
    assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` or `2`"
    assert Path(args.original_checkpoint_path).is_file()

    asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
        scale=args.scale,
        original_checkpoint_path=args.original_checkpoint_path,
        map_location=torch.device(args.map_location),
    )
    print("Saving pretrained AsymmetricAutoencoderKL")
    asymmetric_autoencoder_kl.save_pretrained(args.output_path)
    print(f"Done in {time.time() - start:.2f} seconds")