|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | """Conversion script for the LDM checkpoints.""" | 
					
						
						|  |  | 
					
						
						|  | import argparse | 
					
						
						|  | import json | 
					
						
						|  | import os | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from transformers.file_utils import has_file | 
					
						
						|  |  | 
					
						
						|  | from diffusers import UNet2DConditionModel, UNet2DModel | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | do_only_config = False | 
					
						
						|  | do_only_weights = True | 
					
						
						|  | do_only_renaming = False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | parser = argparse.ArgumentParser() | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | "--repo_path", | 
					
						
						|  | default=None, | 
					
						
						|  | type=str, | 
					
						
						|  | required=True, | 
					
						
						|  | help="The config json file corresponding to the architecture.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") | 
					
						
						|  |  | 
					
						
						|  | args = parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  | config_parameters_to_change = { | 
					
						
						|  | "image_size": "sample_size", | 
					
						
						|  | "num_res_blocks": "layers_per_block", | 
					
						
						|  | "block_channels": "block_out_channels", | 
					
						
						|  | "down_blocks": "down_block_types", | 
					
						
						|  | "up_blocks": "up_block_types", | 
					
						
						|  | "downscale_freq_shift": "freq_shift", | 
					
						
						|  | "resnet_num_groups": "norm_num_groups", | 
					
						
						|  | "resnet_act_fn": "act_fn", | 
					
						
						|  | "resnet_eps": "norm_eps", | 
					
						
						|  | "num_head_channels": "attention_head_dim", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | key_parameters_to_change = { | 
					
						
						|  | "time_steps": "time_proj", | 
					
						
						|  | "mid": "mid_block", | 
					
						
						|  | "downsample_blocks": "down_blocks", | 
					
						
						|  | "upsample_blocks": "up_blocks", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | subfolder = "" if has_file(args.repo_path, "config.json") else "unet" | 
					
						
						|  |  | 
					
						
						|  | with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader: | 
					
						
						|  | text = reader.read() | 
					
						
						|  | config = json.loads(text) | 
					
						
						|  |  | 
					
						
						|  | if do_only_config: | 
					
						
						|  | for key in config_parameters_to_change.keys(): | 
					
						
						|  | config.pop(key, None) | 
					
						
						|  |  | 
					
						
						|  | if has_file(args.repo_path, "config.json"): | 
					
						
						|  | model = UNet2DModel(**config) | 
					
						
						|  | else: | 
					
						
						|  | class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel | 
					
						
						|  | model = class_name(**config) | 
					
						
						|  |  | 
					
						
						|  | if do_only_config: | 
					
						
						|  | model.save_config(os.path.join(args.repo_path, subfolder)) | 
					
						
						|  |  | 
					
						
						|  | config = dict(model.config) | 
					
						
						|  |  | 
					
						
						|  | if do_only_renaming: | 
					
						
						|  | for key, value in config_parameters_to_change.items(): | 
					
						
						|  | if key in config: | 
					
						
						|  | config[value] = config[key] | 
					
						
						|  | del config[key] | 
					
						
						|  |  | 
					
						
						|  | config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]] | 
					
						
						|  | config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]] | 
					
						
						|  |  | 
					
						
						|  | if do_only_weights: | 
					
						
						|  | state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin")) | 
					
						
						|  |  | 
					
						
						|  | new_state_dict = {} | 
					
						
						|  | for param_key, param_value in state_dict.items(): | 
					
						
						|  | if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"): | 
					
						
						|  | continue | 
					
						
						|  | has_changed = False | 
					
						
						|  | for key, new_key in key_parameters_to_change.items(): | 
					
						
						|  | if not has_changed and param_key.split(".")[0] == key: | 
					
						
						|  | new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value | 
					
						
						|  | has_changed = True | 
					
						
						|  | if not has_changed: | 
					
						
						|  | new_state_dict[param_key] = param_value | 
					
						
						|  |  | 
					
						
						|  | model.load_state_dict(new_state_dict) | 
					
						
						|  | model.save_pretrained(os.path.join(args.repo_path, subfolder)) | 
					
						
						|  |  |