# File size: 6,790 Bytes
# 70764d1
import argparse
import re
import torch
import safetensors.torch
def convert_mm_name_to_compvis(key):
    """Rewrite a LoRA weight key as ``<flattened module path>.lora_<suffix>``.

    The key is split at its ``_lora.`` marker. In the module path before the
    marker, every ``processor.`` fragment is dropped, ``to_out`` becomes
    ``to_out.0``, and all dots are then turned into underscores. The text
    after the marker is appended unchanged behind ``.lora_``.

    Args:
        key: Original weight name; must contain ``_lora.`` exactly once.

    Returns:
        The converted key string.

    Raises:
        ValueError: If ``_lora.`` occurs zero times or more than once, since
            the split then does not yield exactly three pieces.
    """
    # The capturing group keeps the separator in the result, so a single
    # match yields exactly [module path, "_lora.", suffix].
    module_path, _marker, lora_suffix = re.split(r'(_lora\.)', key)

    module_path = module_path.replace("processor.", "")
    module_path = module_path.replace("to_out", "to_out.0")

    # Flatten the dotted path only after "to_out.0" was inserted, so that the
    # new dot is flattened as well.
    flattened = "_".join(module_path.split("."))
    return flattened + ".lora_" + lora_suffix
def convert_from_diffuser_state_dict(ad_cn_l):
    """Rename the keys of an HF Diffusers-style state dict to stable-diffusion names.

    Each key is renamed in three ordered steps:

    1. whole-key replacement for a fixed set of top-level parameters;
    2. for keys containing ``"resnets"``, substring replacement of the
       resnet sub-layer names;
    3. substring replacement of block prefixes (down blocks, mid block,
       ControlNet conditioning embedding, ControlNet down-block outputs).

    Values are passed through untouched; only the keys change.

    Args:
        ad_cn_l: Mapping from Diffusers-style parameter names to values.

    Returns:
        A new dict, in the input's iteration order, keyed by the converted
        names. If two inputs convert to the same name, the later value wins.
    """
    # Whole-key renames: HF Diffusers name -> stable-diffusion name.
    exact_renames = {
        "time_embedding.linear_1.weight": "time_embed.0.weight",
        "time_embedding.linear_1.bias": "time_embed.0.bias",
        "time_embedding.linear_2.weight": "time_embed.2.weight",
        "time_embedding.linear_2.bias": "time_embed.2.bias",
        "add_embedding.linear_1.weight": "label_emb.0.0.weight",
        "add_embedding.linear_1.bias": "label_emb.0.0.bias",
        "add_embedding.linear_2.weight": "label_emb.0.2.weight",
        "add_embedding.linear_2.bias": "label_emb.0.2.bias",
        "conv_in.weight": "input_blocks.0.0.weight",
        "conv_in.bias": "input_blocks.0.0.bias",
        "controlnet_mid_block.weight": "middle_block_out.0.weight",
        "controlnet_mid_block.bias": "middle_block_out.0.bias",
    }

    # Sub-layer renames inside resnet blocks: (HF Diffusers, stable-diffusion).
    resnet_renames = (
        ("norm1", "in_layers.0"),
        ("conv1", "in_layers.2"),
        ("norm2", "out_layers.0"),
        ("conv2", "out_layers.3"),
        ("time_emb_proj", "emb_layers.1"),
        ("conv_shortcut", "skip_connection"),
    )

    # Block-prefix renames: (HF Diffusers, stable-diffusion). The list order
    # is the order in which the substitutions are applied.
    layer_renames = []

    # hardcoded number of downblocks and resnets/attentions...
    # would need smarter logic for other networks.
    for block in range(4):
        for unit in range(10):
            sd_index = 3 * block + unit + 1
            layer_renames.append(
                (f"down_blocks.{block}.resnets.{unit}.", f"input_blocks.{sd_index}.0.")
            )
            layer_renames.append(
                (f"down_blocks.{block}.attentions.{unit}.", f"input_blocks.{sd_index}.1.")
            )
        layer_renames.append(
            (f"down_blocks.{block}.downsamplers.0.conv.", f"input_blocks.{3 * (block + 1)}.0.op.")
        )

    layer_renames.append(("mid_block.attentions.0.", "middle_block.1."))
    layer_renames.extend(
        (f"mid_block.resnets.{n}.", f"middle_block.{2 * n}.") for n in range(2)
    )

    # controlnet specific
    hint_layers = ["conv_in", *(f"blocks.{n}" for n in range(6)), "conv_out"]
    layer_renames.extend(
        (f"controlnet_cond_embedding.{layer}.", f"input_hint_block.{pos * 2}.")
        for pos, layer in enumerate(hint_layers)
    )
    layer_renames.extend(
        (f"controlnet_down_blocks.{n}.", f"zero_convs.{n}.0.") for n in range(12)
    )

    def _rename(hf_key):
        # Step 1: whole-key lookup, falling back to the key itself.
        sd_key = exact_renames.get(hf_key, hf_key)
        # Step 2: the resnet test is made on the ORIGINAL key.
        if "resnets" in hf_key:
            for hf_part, sd_part in resnet_renames:
                sd_key = sd_key.replace(hf_part, sd_part)
        # Step 3: block prefixes.
        for hf_part, sd_part in layer_renames:
            sd_key = sd_key.replace(hf_part, sd_part)
        return sd_key

    converted = {}
    for hf_key, value in ad_cn_l.items():
        converted[_rename(hf_key)] = value
    return converted
def lora_conversion(file_path, save_path):
    """Convert the key names of a LoRA checkpoint and save it as safetensors.

    Every key of the loaded state dict is passed through
    ``convert_mm_name_to_compvis``; the tensors themselves are unchanged.

    Args:
        file_path: Path to the source checkpoint. A path ending in
            ``.safetensors`` is read with safetensors; anything else is read
            with ``torch.load``.
        save_path: Destination path. The output is always written with
            ``safetensors.torch.save_file``.

    Raises:
        ValueError: Propagated from ``convert_mm_name_to_compvis`` when a key
            does not contain ``_lora.`` exactly once.
    """
    if file_path.endswith(".safetensors"):
        state_dict = safetensors.torch.load_file(file_path)
    else:
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code -- only convert checkpoints from a trusted source.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # machines without one; the saved output is the same either way.
        state_dict = torch.load(file_path, map_location="cpu")
        # Pickled checkpoints sometimes nest the weights under "state_dict".
        # Without unwrapping, key conversion would fail on that wrapper key.
        nested = state_dict.get("state_dict") if isinstance(state_dict, dict) else None
        if isinstance(nested, dict):
            state_dict = nested

    modified_dict = {convert_mm_name_to_compvis(k): v for k, v in state_dict.items()}
    safetensors.torch.save_file(modified_dict, save_path)
    print(f"LoRA conversion completed: {save_path}")
def controlnet_conversion(ad_cn_old, ad_cn_new, normal_cn_path):
    """Convert a ControlNet checkpoint's key names and save it as float16 safetensors.

    The keys of the source checkpoint are sorted into two groups:

    * keys starting with ``controlnet_cond_embedding`` (their
      ``controlnet_cond_embedding.`` prefix is rewritten to
      ``input_hint_block.0.``) and keys that are absent from the reference
      model but contain ``motion_modules`` -- these skip the Diffusers key
      conversion;
    * keys that also exist in the reference model -- these are renamed by
      ``convert_from_diffuser_state_dict``.

    All tensors are cast to ``torch.float16``. The first group is merged into
    the converted second group last, so it wins on any key collision.

    Args:
        ad_cn_old: Path to the source checkpoint. A path ending in
            ``.safetensors`` is read with safetensors; anything else is read
            with ``torch.load``.
        ad_cn_new: Destination path, written with
            ``safetensors.torch.save_file``.
        normal_cn_path: Path to a reference ControlNet in safetensors format.
            Only its key names are consulted.

    Raises:
        ValueError: If a key is neither a conditioning-embedding key, nor in
            the reference model, nor a ``motion_modules`` key.
    """
    if ad_cn_old.endswith(".safetensors"):
        ad_cn = safetensors.torch.load_file(ad_cn_old)
    else:
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code -- only convert checkpoints from a trusted source.
        # map_location="cpu" lets checkpoints saved from a GPU load on
        # machines without one.
        ad_cn = torch.load(ad_cn_old, map_location="cpu")
    normal_cn = safetensors.torch.load_file(normal_cn_path)

    # ad_cn_l: keys shared with the reference model (renamed below).
    # ad_cn_m: keys kept outside the Diffusers key conversion.
    ad_cn_l, ad_cn_m = {}, {}
    for k, tensor in ad_cn.items():
        if k.startswith("controlnet_cond_embedding"):
            new_key = k.replace("controlnet_cond_embedding.", "input_hint_block.0.")
            ad_cn_m[new_key] = tensor.to(torch.float16)
        elif k in normal_cn:
            ad_cn_l[k] = tensor.to(torch.float16)
        elif "motion_modules" in k:
            ad_cn_m[k] = tensor.to(torch.float16)
        else:
            # Fail before touching the value: an unknown entry may not even
            # be a tensor.
            raise ValueError(f"{k} not in normal_cn")

    ad_cn_l = convert_from_diffuser_state_dict(ad_cn_l)
    ad_cn_l.update(ad_cn_m)
    safetensors.torch.save_file(ad_cn_l, ad_cn_new)
    print(f"ControlNet conversion completed: {ad_cn_new}")
def main():
    """Parse the command line and run the requested conversion.

    Two sub-commands are registered:

    * ``lora file_path save_path`` -> ``lora_conversion``
    * ``controlnet ad_cn_old ad_cn_new normal_cn_path`` -> ``controlnet_conversion``

    When no sub-command is given, the help text is printed instead.
    """
    parser = argparse.ArgumentParser(description="Script to convert LoRA and ControlNet models.")
    commands = parser.add_subparsers(dest='command')

    # LoRA conversion parser
    lora_cmd = commands.add_parser('lora', help='LoRA conversion')
    lora_positionals = (
        ('file_path', 'Path to the old LoRA checkpoint'),
        ('save_path', 'Path to save the new LoRA checkpoint'),
    )
    for arg_name, arg_help in lora_positionals:
        lora_cmd.add_argument(arg_name, type=str, help=arg_help)

    # ControlNet conversion parser
    cn_cmd = commands.add_parser('controlnet', help='ControlNet conversion')
    cn_positionals = (
        ('ad_cn_old', 'Path to the old sparse ControlNet checkpoint'),
        ('ad_cn_new', 'Path to save the new sparse ControlNet checkpoint'),
        ('normal_cn_path', 'Path to the normal ControlNet model'),
    )
    for arg_name, arg_help in cn_positionals:
        cn_cmd.add_argument(arg_name, type=str, help=arg_help)

    options = parser.parse_args()

    if options.command == 'lora':
        lora_conversion(options.file_path, options.save_path)
        return
    if options.command == 'controlnet':
        controlnet_conversion(options.ad_cn_old, options.ad_cn_new, options.normal_cn_path)
        return
    # No sub-command supplied.
    parser.print_help()
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|