doc
Browse files
- README.md +19 -48
- convert.py +158 -0

README.md
CHANGED
@@ -1,62 +1,33 @@
 # AnimateDiff Model Checkpoints for A1111 SD WebUI
 This repository saves all AnimateDiff models in fp16 & safetensors format for A1111 AnimateDiff users, including
 - motion module (v1-v3)
-- [motion LoRA](#motion-lora) (v2 only, use like any other
-- domain adapter (v3 only, use like any other
-- [sparse ControlNet](#sparse-controlnet) (v3 only, use like any other
+- [motion LoRA](#motion-lora) (v2 only, use like any other LoRA)
+- domain adapter (v3 only, use like any other LoRA)
+- [sparse ControlNet](#sparse-controlnet) (v3 only, use like any other ControlNet)
 
 Unless specified below, you are fine to use models from the [official model repository](https://huggingface.co/guoyww/animatediff/tree/main). I will only convert state dict keys if absolutely necessary.
 
 
 ## Motion LoRA
-Put
 
-
 
-
-
-
-
 
-def convert_mm_name_to_compvis(key):
-    sd_module_key, _, network_part = re.split(r'(_lora\.)', key)
-    sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
-    sd_module_key = sd_module_key.replace(".", "_")
-    return f'{sd_module_key}.lora_{network_part}'
-
-
-state_dict = safetensors.torch.load_file(file_path) if file_path.endswith(".safetensors") else torch.load(file_path)
-modified_dict = {convert_mm_name_to_compvis(k): v for k, v in state_dict.items()}
-safetensors.torch.save_file(modified_dict, save_path)
-```
 
 
-
-
-
-
-
-import safetensors.torch
-ad_cn_old = "v3_sd15_sparsectrl_scribble.ckpt" # replace with path to your own old sparse ControlNet checkpoint
-ad_cn_new = "mm_sd15_v3_sparsectrl_scribble.safetensors" # replace with path to your own new sparse ControlNet checkpoint
-normal_cn_path = "diffusion_pytorch_model.fp16.safetensors" # download https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/diffusion_pytorch_model.fp16.safetensors?download=true and replace with the path to this model
-ad_cn = safetensors.torch.load_file(file_path) if file_path.endswith(".safetensors") else torch.load(ad_cn_old)
-normal_cn = safetensors.torch.load_file(normal_cn_path)
-ad_cn_l, ad_cn_m = {}, {}
-for k in ad_cn.keys():
-    if k.startswith("controlnet_cond_embedding"):
-        new_key = k.replace("controlnet_cond_embedding.", "input_hint_block.0.")
-        ad_cn_m[new_key] = ad_cn[k].to(torch.float16)
-    elif not k in normal_cn:
-        if "motion_modules" in k:
-            ad_cn_m[k] = ad_cn[k].to(torch.float16)
-        else:
-            raise Exception(f"{k} not in normal_cn")
-    else:
-        ad_cn_l[k] = ad_cn[k].to(torch.float16)
-ad_cn_l = convert_from_diffuser_state_dict(ad_cn_l)
-ad_cn_l.update(ad_cn_m)
-safetensors.torch.save_file(ad_cn_l, ad_cn_new)
-```
+Put Motion LoRAs in `stable-diffusion-webui/models/Lora` and use them like any other LoRA.
+
+[lora_v2](lora_v2) contains motion LoRAs for AnimateDiff-A1111 v2.0.0. I converted the state dict keys inside these motion LoRAs; the original motion LoRAs won't work with AnimateDiff-A1111 v2.0.0 and later for maintenance reasons.
+
+Use [convert.py](convert.py) in the following way if you want to convert a third-party motion LoRA to be compatible with A1111 (see the example after this list):
+- Activate your A1111 Python environment first.
+- Command: `python convert.py lora [file_path] [save_path]`
+- Replace `[file_path]` with the path to your old LoRA checkpoint.
+- Replace `[save_path]` with the path where you want to save the new LoRA checkpoint.
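+
+For example, to convert a downloaded motion LoRA (the file names below are placeholders for your own paths):
+
+```
+python convert.py lora v2_lora_PanLeft.ckpt lora_v2/v2_lora_PanLeft.safetensors
+```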
+
+## Sparse ControlNet
+Put sparse ControlNets in `stable-diffusion-webui/models/ControlNet` and use them like any other ControlNet.
+
+Like the Motion LoRAs, I converted the state dict keys inside these sparse ControlNets; the original sparse ControlNets won't work with A1111 for maintenance reasons.
+
+Use [convert.py](convert.py) in the following way if you want to convert a third-party sparse ControlNet to be compatible with A1111 (see the example after this list):
+- Activate your A1111 Python environment first.
+- Command: `python convert.py controlnet [ad_cn_old] [ad_cn_new] [normal_cn_path]`
+- Replace `[ad_cn_old]` with the path to your old sparse ControlNet checkpoint.
+- Replace `[ad_cn_new]` with the path where you want to save the new sparse ControlNet checkpoint.
+- Replace `[normal_cn_path]` with the path to a normal ControlNet model. Download normal ControlNet from [here](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/diffusion_pytorch_model.fp16.safetensors?download=true).
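+
+For example (the file names are examples; replace them with your own paths):
+
+```
+python convert.py controlnet v3_sd15_sparsectrl_scribble.ckpt mm_sd15_v3_sparsectrl_scribble.safetensors diffusion_pytorch_model.fp16.safetensors
+```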
convert.py
ADDED
@@ -0,0 +1,158 @@
+import argparse
+import re
+import torch
+import safetensors.torch
+
+
+def convert_mm_name_to_compvis(key):
+    sd_module_key, _, network_part = re.split(r'(_lora\.)', key)
+    sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
+    sd_module_key = sd_module_key.replace(".", "_")
+    return f'{sd_module_key}.lora_{network_part}'
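+# Example (illustrative): a Diffusers-style key such as
+# "down_blocks.0.motion_modules.0.attention_blocks.0.processor.to_q_lora.down.weight"
+# is renamed to "down_blocks_0_motion_modules_0_attention_blocks_0_to_q.lora_down.weight".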
+
+def convert_from_diffuser_state_dict(ad_cn_l):
+    unet_conversion_map = [
+        # (stable-diffusion, HF Diffusers)
+        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+        ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+        ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+        ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+        ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+        ("input_blocks.0.0.weight", "conv_in.weight"),
+        ("input_blocks.0.0.bias", "conv_in.bias"),
+        ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
+        ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
+    ]
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0", "norm1"),
+        ("in_layers.2", "conv1"),
+        ("out_layers.0", "norm2"),
+        ("out_layers.3", "conv2"),
+        ("emb_layers.1", "time_emb_proj"),
+        ("skip_connection", "conv_shortcut"),
+    ]
+
+    unet_conversion_map_layer = []
+    # hardcoded number of downblocks and resnets/attentions...
+    # would need smarter logic for other networks.
+    for i in range(4):
+        # loop over downblocks/upblocks
+
+        for j in range(10):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+        sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    # controlnet specific
+
+    controlnet_cond_embedding_names = ['conv_in'] + [f'blocks.{i}' for i in range(6)] + ['conv_out']
+    for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
+        hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
+        sd_prefix = f"input_hint_block.{i*2}."
+        unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+    for i in range(12):
+        hf_prefix = f"controlnet_down_blocks.{i}."
+        sd_prefix = f"zero_convs.{i}.0."
+        unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+
+    def _convert_from_diffuser_state_dict(unet_state_dict):
+        mapping = {k: k for k in unet_state_dict.keys()}
+        for sd_name, hf_name in unet_conversion_map:
+            mapping[hf_name] = sd_name
+        for k, v in mapping.items():
+            if "resnets" in k:
+                for sd_part, hf_part in unet_conversion_map_resnet:
+                    v = v.replace(hf_part, sd_part)
+                mapping[k] = v
+        for k, v in mapping.items():
+            for sd_part, hf_part in unet_conversion_map_layer:
+                v = v.replace(hf_part, sd_part)
+            mapping[k] = v
+        new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items() if k in unet_state_dict}
+        return new_state_dict
+
+    return _convert_from_diffuser_state_dict(ad_cn_l)
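+# Example of the mapping's effect (illustrative): per unet_conversion_map above,
+# the Diffusers key "time_embedding.linear_1.weight" is stored as the
+# SD/CompVis key "time_embed.0.weight" in the converted state dict.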
+
+
+def lora_conversion(file_path, save_path):
+    state_dict = safetensors.torch.load_file(file_path) if file_path.endswith(".safetensors") else torch.load(file_path)
+    modified_dict = {convert_mm_name_to_compvis(k): v for k, v in state_dict.items()}
+    safetensors.torch.save_file(modified_dict, save_path)
+    print(f"LoRA conversion completed: {save_path}")
+
+
+def controlnet_conversion(ad_cn_old, ad_cn_new, normal_cn_path):
+    ad_cn = safetensors.torch.load_file(ad_cn_old) if ad_cn_old.endswith(".safetensors") else torch.load(ad_cn_old)
+    normal_cn = safetensors.torch.load_file(normal_cn_path)
+    ad_cn_l, ad_cn_m = {}, {}
+
+    for k in ad_cn.keys():
+        if k.startswith("controlnet_cond_embedding"):
+            new_key = k.replace("controlnet_cond_embedding.", "input_hint_block.0.")
+            ad_cn_m[new_key] = ad_cn[k].to(torch.float16)
+        elif k not in normal_cn:
+            if "motion_modules" in k:
+                ad_cn_m[k] = ad_cn[k].to(torch.float16)
+            else:
+                raise Exception(f"{k} not in normal_cn")
+        else:
+            ad_cn_l[k] = ad_cn[k].to(torch.float16)
+
+    ad_cn_l = convert_from_diffuser_state_dict(ad_cn_l)
+    ad_cn_l.update(ad_cn_m)
+    safetensors.torch.save_file(ad_cn_l, ad_cn_new)
+    print(f"ControlNet conversion completed: {ad_cn_new}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Script to convert LoRA and ControlNet models.")
+    subparsers = parser.add_subparsers(dest='command')
+
+    # LoRA conversion parser
+    lora_parser = subparsers.add_parser('lora', help='LoRA conversion')
+    lora_parser.add_argument('file_path', type=str, help='Path to the old LoRA checkpoint')
+    lora_parser.add_argument('save_path', type=str, help='Path to save the new LoRA checkpoint')
+
+    # ControlNet conversion parser
+    cn_parser = subparsers.add_parser('controlnet', help='ControlNet conversion')
+    cn_parser.add_argument('ad_cn_old', type=str, help='Path to the old sparse ControlNet checkpoint')
+    cn_parser.add_argument('ad_cn_new', type=str, help='Path to save the new sparse ControlNet checkpoint')
+    cn_parser.add_argument('normal_cn_path', type=str, help='Path to the normal ControlNet model')
+
+    args = parser.parse_args()
+
+    if args.command == 'lora':
+        lora_conversion(args.file_path, args.save_path)
+    elif args.command == 'controlnet':
+        controlnet_conversion(args.ad_cn_old, args.ad_cn_new, args.normal_cn_path)
+    else:
+        parser.print_help()
+
+
+if __name__ == "__main__":
+    main()