conrevo committed on
Commit 70764d1
1 Parent(s): afd8b4f
Files changed (2)
  1. README.md +19 -48
  2. convert.py +158 -0
README.md CHANGED
@@ -1,62 +1,33 @@
  # AnimateDiff Model Checkpoints for A1111 SD WebUI
  This repository saves all AnimateDiff models in fp16 & safetensors format for A1111 AnimateDiff users, including
  - motion module (v1-v3)
- - [motion LoRA](#motion-lora) (v2 only, use like any other LoRAs)
- - domain adapter (v3 only, use like any other LoRAs)
- - [sparse ControlNet](#sparse-controlnet) (v3 only, use like any other ControlNets)
+ - [motion LoRA](#motion-lora) (v2 only, use like any other LoRA)
+ - domain adapter (v3 only, use like any other LoRA)
+ - [sparse ControlNet](#sparse-controlnet) (v3 only, use like any other ControlNet)

  Unless specified below, you can use models from the [official model repository](https://huggingface.co/guoyww/animatediff/tree/main) as-is. I will only convert state dict keys if absolutely necessary.

  ## Motion LoRA
- Put motion LoRAs to `stable-diffusion-webui/models/Lora` and use motion LoRAs like any other LoRAs you use.
-
- `lora_v2` contains motion LoRAs for AnimateDiff-A1111 v2.0.0. Old motion LoRAs won't work for v2.0.0 and later due to maintenance reason. `lora` will be removed after AnimateDiff-A1111 v2.0.0 is released to master branch.
-
- I converted the original state dict via the following code. You may do so if you want to use a motion LoRA from community. Run the following script to make your own motion LoRA checkpoint compatible with AnimateDiff-A1111 v2.0.0 and later.
- ```python
- import os, re, torch
- import safetensors.torch
-
- def convert_mm_name_to_compvis(key):
-     sd_module_key, _, network_part = re.split(r'(_lora\.)', key)
-     sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
-     sd_module_key = sd_module_key.replace(".", "_")
-     return f'{sd_module_key}.lora_{network_part}'
-
- file_path = ""  # replace with path to your own old motion LoRA checkpoint
- save_path = ""  # replace with path to your own new motion LoRA checkpoint
- state_dict = safetensors.torch.load_file(file_path) if file_path.endswith(".safetensors") else torch.load(file_path)
- modified_dict = {convert_mm_name_to_compvis(k): v for k, v in state_dict.items()}
- safetensors.torch.save_file(modified_dict, save_path)
- ```
+ Put motion LoRAs in `stable-diffusion-webui/models/Lora` and use them like any other LoRA.
+
+ [lora_v2](lora_v2) contains motion LoRAs for AnimateDiff-A1111 v2.0.0. I converted the state dict keys inside these motion LoRAs; the original motion LoRAs will not work with AnimateDiff-A1111 v2.0.0 and later for maintenance reasons.
+
+ Use [convert.py](convert.py) as follows if you want to make a third-party motion LoRA compatible with A1111 (the key renaming it applies is sketched after this list):
+ - Activate your A1111 Python environment first.
+ - Command: `python convert.py lora [file_path] [save_path]`
+ - Replace `[file_path]` with the path to your old LoRA checkpoint.
+ - Replace `[save_path]` with the path where you want to save the new LoRA checkpoint.
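+
+ For example: `python convert.py lora old_motion_lora.safetensors motion_lora_converted.safetensors` (hypothetical file names). Under the hood, convert.py simply renames every key in the checkpoint; a minimal sketch of that renaming, applied to a made-up diffusers-style key, looks like this:
+ ```python
+ import re
+
+ def convert_mm_name_to_compvis(key):
+     # split the key into the module path and the LoRA up/down part
+     sd_module_key, _, network_part = re.split(r'(_lora\.)', key)
+     sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
+     sd_module_key = sd_module_key.replace(".", "_")
+     return f'{sd_module_key}.lora_{network_part}'
+
+ # hypothetical old-style key, for illustration only
+ print(convert_mm_name_to_compvis("down_blocks.0.motion_modules.0.processor.to_q_lora.down.weight"))
+ # -> down_blocks_0_motion_modules_0_to_q.lora_down.weight
+ ```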

- ## Sparse ControlNet
- Put Sparse ControlNets to `stable-diffusion-webui/models/ControlNet` and use Sparse ControlNets like any other ControlNets you use.
-
- Like Motion LoRA, I also converted state dict keys inside sparse ControlNet. Run the following script to make your own sparse ControlNet checkpoint compatible with AnimateDiff-A1111.
- ```python
- import torch
- import safetensors.torch
- ad_cn_old = "v3_sd15_sparsectrl_scribble.ckpt"  # replace with path to your own old sparse ControlNet checkpoint
- ad_cn_new = "mm_sd15_v3_sparsectrl_scribble.safetensors"  # replace with path to your own new sparse ControlNet checkpoint
- normal_cn_path = "diffusion_pytorch_model.fp16.safetensors"  # download https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/diffusion_pytorch_model.fp16.safetensors?download=true and replace with the path to this model
- ad_cn = safetensors.torch.load_file(ad_cn_old) if ad_cn_old.endswith(".safetensors") else torch.load(ad_cn_old)
- normal_cn = safetensors.torch.load_file(normal_cn_path)
- ad_cn_l, ad_cn_m = {}, {}
- for k in ad_cn.keys():
-     if k.startswith("controlnet_cond_embedding"):
-         new_key = k.replace("controlnet_cond_embedding.", "input_hint_block.0.")
-         ad_cn_m[new_key] = ad_cn[k].to(torch.float16)
-     elif k not in normal_cn:
-         if "motion_modules" in k:
-             ad_cn_m[k] = ad_cn[k].to(torch.float16)
-         else:
-             raise Exception(f"{k} not in normal_cn")
-     else:
-         ad_cn_l[k] = ad_cn[k].to(torch.float16)
- ad_cn_l = convert_from_diffuser_state_dict(ad_cn_l)
- ad_cn_l.update(ad_cn_m)
- safetensors.torch.save_file(ad_cn_l, ad_cn_new)
- ```
+ ## Sparse ControlNet
+ Put sparse ControlNets in `stable-diffusion-webui/models/ControlNet` and use them like any other ControlNet.
+
+ As with motion LoRAs, I converted the state dict keys inside the sparse ControlNets; the original sparse ControlNets will not work with A1111 for maintenance reasons.
+
+ Use [convert.py](convert.py) as follows if you want to make a third-party sparse ControlNet compatible with A1111 (a sanity-check sketch follows this list):
+ - Activate your A1111 Python environment first.
+ - Command: `python convert.py controlnet [ad_cn_old] [ad_cn_new] [normal_cn_path]`
+ - Replace `[ad_cn_old]` with the path to your old sparse ControlNet checkpoint.
+ - Replace `[ad_cn_new]` with the path where you want to save the new sparse ControlNet checkpoint.
+ - Replace `[normal_cn_path]` with the path to the normal ControlNet model, which you can download from [here](https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/diffusion_pytorch_model.fp16.safetensors?download=true).
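+
+ For example: `python convert.py controlnet v3_sd15_sparsectrl_scribble.ckpt mm_sd15_v3_sparsectrl_scribble.safetensors diffusion_pytorch_model.fp16.safetensors`. If you want to verify the result, a minimal sanity check might look like the sketch below (the file name is an example; the checks only cover what convert.py is expected to produce, namely ldm-style keys in fp16):
+ ```python
+ import torch
+ import safetensors.torch
+
+ sd = safetensors.torch.load_file("mm_sd15_v3_sparsectrl_scribble.safetensors")
+ # converted checkpoints should contain ldm-style hint-block keys and be fully fp16
+ assert any(k.startswith("input_hint_block.") for k in sd)
+ assert all(v.dtype == torch.float16 for v in sd.values())
+ print(f"{len(sd)} tensors look OK")
+ ```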
convert.py ADDED
@@ -0,0 +1,158 @@
+ import argparse
+ import re
+ import torch
+ import safetensors.torch
+
+
+ def convert_mm_name_to_compvis(key):
+     # Rename a diffusers-style motion LoRA key to the compvis-style key A1111 expects.
+     sd_module_key, _, network_part = re.split(r'(_lora\.)', key)
+     sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
+     sd_module_key = sd_module_key.replace(".", "_")
+     return f'{sd_module_key}.lora_{network_part}'
+
+
+ def convert_from_diffuser_state_dict(ad_cn_l):
+     # Map HF Diffusers ControlNet/UNet key names back to stable-diffusion (ldm) key names.
+     unet_conversion_map = [
+         # (stable-diffusion, HF Diffusers)
+         ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+         ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+         ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+         ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+         ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+         ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+         ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+         ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+         ("input_blocks.0.0.weight", "conv_in.weight"),
+         ("input_blocks.0.0.bias", "conv_in.bias"),
+         ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
+         ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
+     ]
+
+     unet_conversion_map_resnet = [
+         # (stable-diffusion, HF Diffusers)
+         ("in_layers.0", "norm1"),
+         ("in_layers.2", "conv1"),
+         ("out_layers.0", "norm2"),
+         ("out_layers.3", "conv2"),
+         ("emb_layers.1", "time_emb_proj"),
+         ("skip_connection", "conv_shortcut"),
+     ]
+
+     unet_conversion_map_layer = []
+     # hardcoded number of downblocks and resnets/attentions...
+     # would need smarter logic for other networks.
+     for i in range(4):
+         # loop over downblocks/upblocks
+         for j in range(10):
+             # loop over resnets/attentions for downblocks
+             hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+             sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+             unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+             hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+             sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+             unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+         sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+         unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+     hf_mid_atn_prefix = "mid_block.attentions.0."
+     sd_mid_atn_prefix = "middle_block.1."
+     unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+     for j in range(2):
+         hf_mid_res_prefix = f"mid_block.resnets.{j}."
+         sd_mid_res_prefix = f"middle_block.{2*j}."
+         unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+     # controlnet specific
+     controlnet_cond_embedding_names = ['conv_in'] + [f'blocks.{i}' for i in range(6)] + ['conv_out']
+     for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
+         hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
+         sd_prefix = f"input_hint_block.{i*2}."
+         unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+     for i in range(12):
+         hf_prefix = f"controlnet_down_blocks.{i}."
+         sd_prefix = f"zero_convs.{i}.0."
+         unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+     def _convert_from_diffuser_state_dict(unet_state_dict):
+         # Start from identity mapping, then rewrite values using the tables above.
+         mapping = {k: k for k in unet_state_dict.keys()}
+         for sd_name, hf_name in unet_conversion_map:
+             mapping[hf_name] = sd_name
+         for k, v in mapping.items():
+             if "resnets" in k:
+                 for sd_part, hf_part in unet_conversion_map_resnet:
+                     v = v.replace(hf_part, sd_part)
+                 mapping[k] = v
+         for k, v in mapping.items():
+             for sd_part, hf_part in unet_conversion_map_layer:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+         new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items() if k in unet_state_dict}
+         return new_state_dict
+
+     return _convert_from_diffuser_state_dict(ad_cn_l)
+
+
+ def lora_conversion(file_path, save_path):
+     # Load the old motion LoRA, rename every key, and save as safetensors.
+     state_dict = safetensors.torch.load_file(file_path) if file_path.endswith(".safetensors") else torch.load(file_path)
+     modified_dict = {convert_mm_name_to_compvis(k): v for k, v in state_dict.items()}
+     safetensors.torch.save_file(modified_dict, save_path)
+     print(f"LoRA conversion completed: {save_path}")
+
+
+ def controlnet_conversion(ad_cn_old, ad_cn_new, normal_cn_path):
+     # Split the sparse ControlNet into ldm-convertible keys (ad_cn_l) and
+     # AnimateDiff-specific keys (ad_cn_m), cast everything to fp16, then merge.
+     ad_cn = safetensors.torch.load_file(ad_cn_old) if ad_cn_old.endswith(".safetensors") else torch.load(ad_cn_old)
+     normal_cn = safetensors.torch.load_file(normal_cn_path)
+     ad_cn_l, ad_cn_m = {}, {}
+
+     for k in ad_cn.keys():
+         if k.startswith("controlnet_cond_embedding"):
+             new_key = k.replace("controlnet_cond_embedding.", "input_hint_block.0.")
+             ad_cn_m[new_key] = ad_cn[k].to(torch.float16)
+         elif k not in normal_cn:
+             if "motion_modules" in k:
+                 ad_cn_m[k] = ad_cn[k].to(torch.float16)
+             else:
+                 raise Exception(f"{k} not in normal_cn")
+         else:
+             ad_cn_l[k] = ad_cn[k].to(torch.float16)
+
+     ad_cn_l = convert_from_diffuser_state_dict(ad_cn_l)
+     ad_cn_l.update(ad_cn_m)
+     safetensors.torch.save_file(ad_cn_l, ad_cn_new)
+     print(f"ControlNet conversion completed: {ad_cn_new}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Script to convert LoRA and ControlNet models.")
+     subparsers = parser.add_subparsers(dest='command')
+
+     # LoRA conversion parser
+     lora_parser = subparsers.add_parser('lora', help='LoRA conversion')
+     lora_parser.add_argument('file_path', type=str, help='Path to the old LoRA checkpoint')
+     lora_parser.add_argument('save_path', type=str, help='Path to save the new LoRA checkpoint')
+
+     # ControlNet conversion parser
+     cn_parser = subparsers.add_parser('controlnet', help='ControlNet conversion')
+     cn_parser.add_argument('ad_cn_old', type=str, help='Path to the old sparse ControlNet checkpoint')
+     cn_parser.add_argument('ad_cn_new', type=str, help='Path to save the new sparse ControlNet checkpoint')
+     cn_parser.add_argument('normal_cn_path', type=str, help='Path to the normal ControlNet model')
+
+     args = parser.parse_args()
+
+     if args.command == 'lora':
+         lora_conversion(args.file_path, args.save_path)
+     elif args.command == 'controlnet':
+         controlnet_conversion(args.ad_cn_old, args.ad_cn_new, args.normal_cn_path)
+     else:
+         parser.print_help()
+
+
+ if __name__ == "__main__":
+     main()