etx committed on
Commit 860e0da
1 Parent(s): 8dd8da0
Files changed (2)
  1. convert_diffusers_to_sd.py +234 -0
  2. model-1200.ckpt +3 -0
convert_diffusers_to_sd.py ADDED
@@ -0,0 +1,234 @@
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+ # *Only* converts the UNet, VAE, and Text Encoder.
+ # Does not convert optimizer state or anything else.
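+ #
+ # Example invocation (the paths here are illustrative, not from this repo):
+ #   python convert_diffusers_to_sd.py --model_path ./my_diffusers_model \
+ #       --checkpoint_path ./model.ckpt --half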
+
+ import argparse
+ import os.path as osp
+
+ import torch
+
+
+ # =================#
+ # UNet Conversion #
+ # =================#
+
+ unet_conversion_map = [
+     # (stable-diffusion, HF Diffusers)
+     ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+     ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+     ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+     ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+     ("input_blocks.0.0.weight", "conv_in.weight"),
+     ("input_blocks.0.0.bias", "conv_in.bias"),
+     ("out.0.weight", "conv_norm_out.weight"),
+     ("out.0.bias", "conv_norm_out.bias"),
+     ("out.2.weight", "conv_out.weight"),
+     ("out.2.bias", "conv_out.bias"),
+ ]
+
+ unet_conversion_map_resnet = [
+     # (stable-diffusion, HF Diffusers)
+     ("in_layers.0", "norm1"),
+     ("in_layers.2", "conv1"),
+     ("out_layers.0", "norm2"),
+     ("out_layers.3", "conv2"),
+     ("emb_layers.1", "time_emb_proj"),
+     ("skip_connection", "conv_shortcut"),
+ ]
+
+ unet_conversion_map_layer = []
+ # hardcoded number of downblocks and resnets/attentions...
+ # would need smarter logic for other networks.
+ for i in range(4):
+     # loop over downblocks/upblocks
+
+     for j in range(2):
+         # loop over resnets/attentions for downblocks
+         hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+         sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+         unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+         if i < 3:
+             # no attention layers in down_blocks.3
+             hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+             sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+             unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+     for j in range(3):
+         # loop over resnets/attentions for upblocks
+         hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+         sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+         unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+         if i > 0:
+             # no attention layers in up_blocks.0
+             hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+             sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+             unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+     if i < 3:
+         # no downsample in down_blocks.3
+         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+         sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+         unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+         # no upsample in up_blocks.3
+         hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+         sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+         unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+ for j in range(2):
+     hf_mid_res_prefix = f"mid_block.resnets.{j}."
+     sd_mid_res_prefix = f"middle_block.{2*j}."
+     unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
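+ # e.g. for i == 0, j == 0 the first loop above appends
+ # ("input_blocks.1.0.", "down_blocks.0.resnets.0.") to unet_conversion_map_layer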
+
+
+ def convert_unet_state_dict(unet_state_dict):
+     # buyer beware: this is a *brittle* function,
+     # and correct output requires that all of these pieces interact in
+     # the exact order in which I have arranged them.
+     mapping = {k: k for k in unet_state_dict.keys()}
+     for sd_name, hf_name in unet_conversion_map:
+         mapping[hf_name] = sd_name
+     for k, v in mapping.items():
+         if "resnets" in k:
+             for sd_part, hf_part in unet_conversion_map_resnet:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+     for k, v in mapping.items():
+         for sd_part, hf_part in unet_conversion_map_layer:
+             v = v.replace(hf_part, sd_part)
+         mapping[k] = v
+     new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+     return new_state_dict
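+ # e.g. convert_unet_state_dict renames "down_blocks.0.resnets.0.norm1.weight"
+ # to "input_blocks.1.0.in_layers.0.weight"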
+
+
+ # ================#
+ # VAE Conversion #
+ # ================#
+
+ vae_conversion_map = [
+     # (stable-diffusion, HF Diffusers)
+     ("nin_shortcut", "conv_shortcut"),
+     ("norm_out", "conv_norm_out"),
+     ("mid.attn_1.", "mid_block.attentions.0."),
+ ]
+
+ for i in range(4):
+     # down_blocks have two resnets
+     for j in range(2):
+         hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+         sd_down_prefix = f"encoder.down.{i}.block.{j}."
+         vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+     if i < 3:
+         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+         sd_downsample_prefix = f"down.{i}.downsample."
+         vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+         hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+         sd_upsample_prefix = f"up.{3-i}.upsample."
+         vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+     # up_blocks have three resnets
+     # also, up blocks in hf are numbered in reverse from sd
+     for j in range(3):
+         hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+         sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+         vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+ # this part accounts for mid blocks in both the encoder and the decoder
+ for i in range(2):
+     hf_mid_res_prefix = f"mid_block.resnets.{i}."
+     sd_mid_res_prefix = f"mid.block_{i+1}."
+     vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+ vae_conversion_map_attn = [
+     # (stable-diffusion, HF Diffusers)
+     ("norm.", "group_norm."),
+     ("q.", "query."),
+     ("k.", "key."),
+     ("v.", "value."),
+     ("proj_out.", "proj_attn."),
+ ]
+
+
+ def reshape_weight_for_sd(w):
+     # convert HF linear weights to SD conv2d weights
+     return w.reshape(*w.shape, 1, 1)
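+ # e.g. a [512, 512] linear weight becomes a [512, 512, 1, 1] conv weight
+ # (the VAE attention q/k/v/proj layers are Linear in Diffusers but 1x1 Conv2d in SD)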
+
+
+ def convert_vae_state_dict(vae_state_dict):
+     mapping = {k: k for k in vae_state_dict.keys()}
+     for k, v in mapping.items():
+         for sd_part, hf_part in vae_conversion_map:
+             v = v.replace(hf_part, sd_part)
+         mapping[k] = v
+     for k, v in mapping.items():
+         if "attentions" in k:
+             for sd_part, hf_part in vae_conversion_map_attn:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+     new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+     weights_to_convert = ["q", "k", "v", "proj_out"]
+     for k, v in new_state_dict.items():
+         for weight_name in weights_to_convert:
+             if f"mid.attn_1.{weight_name}.weight" in k:
+                 print(f"Reshaping {k} for SD format")
+                 new_state_dict[k] = reshape_weight_for_sd(v)
+     return new_state_dict
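+ # e.g. convert_vae_state_dict renames "decoder.mid_block.attentions.0.query.weight"
+ # to "decoder.mid.attn_1.q.weight" and reshapes it from 2D to 4D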
+
+
+ # =========================#
+ # Text Encoder Conversion #
+ # =========================#
+ # pretty much a no-op
+
+
+ def convert_text_enc_state_dict(text_enc_dict):
+     return text_enc_dict
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+     parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+
+     args = parser.parse_args()
+
+     assert args.model_path is not None, "Must provide a model path!"
+
+     assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+     unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+     vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+     text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+
+     # Convert the UNet model
+     unet_state_dict = torch.load(unet_path, map_location="cpu")
+     unet_state_dict = convert_unet_state_dict(unet_state_dict)
+     unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+     # Convert the VAE model
+     vae_state_dict = torch.load(vae_path, map_location="cpu")
+     vae_state_dict = convert_vae_state_dict(vae_state_dict)
+     vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+     # Convert the text encoder model
+     text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+     text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+     text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+
+     # Put together new checkpoint
+     state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+     if args.half:
+         state_dict = {k: v.half() for k, v in state_dict.items()}
+     state_dict = {"state_dict": state_dict}
+     torch.save(state_dict, args.checkpoint_path)
model-1200.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:410fd6a7fe948619f9f4d7d8d2a8debd2da7ab9ffc5c5628ff52c6065b6ad2de
+ size 4265327726