John6666 committed on
Commit 7c177b4 · verified · 1 Parent(s): 07b4e15

Upload 2 files

convert_repo_to_safetensors_sd.py CHANGED
@@ -1,12 +1,13 @@
 # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
 # *Only* converts the UNet, VAE, and Text Encoder.
 # Does not convert optimizer state or any other thing.
-# Written by jachiam
 
 import argparse
 import os.path as osp
+import re
 
 import torch
+from safetensors.torch import load_file, save_file
 
 
 # =================#
@@ -158,10 +159,21 @@ vae_conversion_map_attn = [
     ("proj_out.", "proj_attn."),
 ]
 
+# This is probably not the most ideal solution, but it does work.
+vae_extra_conversion_map = [
+    ("to_q", "q"),
+    ("to_k", "k"),
+    ("to_v", "v"),
+    ("to_out.0", "proj_out"),
+]
+
 
 def reshape_weight_for_sd(w):
     # convert HF linear weights to SD conv2d weights
-    return w.reshape(*w.shape, 1, 1)
+    if not w.ndim == 1:
+        return w.reshape(*w.shape, 1, 1)
+    else:
+        return w
 
 
 def convert_vae_state_dict(vae_state_dict):
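
Note: the reshape guarded in this hunk converts a Diffusers Linear weight into the 1x1 Conv2d layout that the original SD VAE attention blocks use; the added ndim check simply passes 1-D bias tensors through untouched. A minimal illustration (the 512-channel width is only an example):

import torch

w = torch.randn(512, 512)          # Diffusers Linear weight: [out_features, in_features]
w_sd = w.reshape(*w.shape, 1, 1)   # SD VAE Conv2d weight: [out, in, 1, 1]
assert w_sd.shape == (512, 512, 1, 1)

b = torch.randn(512)               # biases are 1-D and are now returned unchanged
assert b.ndim == 1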
@@ -177,18 +189,92 @@ def convert_vae_state_dict(vae_state_dict):
         mapping[k] = v
     new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
     weights_to_convert = ["q", "k", "v", "proj_out"]
+    keys_to_rename = {}
     for k, v in new_state_dict.items():
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
                 print(f"Reshaping {k} for SD format")
                 new_state_dict[k] = reshape_weight_for_sd(v)
+        for weight_name, real_weight_name in vae_extra_conversion_map:
+            if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
+                keys_to_rename[k] = k.replace(weight_name, real_weight_name)
+    for k, v in keys_to_rename.items():
+        if k in new_state_dict:
+            print(f"Renaming {k} to {v}")
+            new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
+            del new_state_dict[k]
     return new_state_dict
 
 
 # =========================#
 # Text Encoder Conversion #
 # =========================#
-# pretty much a no-op
+
+
+textenc_conversion_lst = [
+    # (stable-diffusion, HF Diffusers)
+    ("resblocks.", "text_model.encoder.layers."),
+    ("ln_1", "layer_norm1"),
+    ("ln_2", "layer_norm2"),
+    (".c_fc.", ".fc1."),
+    (".c_proj.", ".fc2."),
+    (".attn", ".self_attn"),
+    ("ln_final.", "transformer.text_model.final_layer_norm."),
+    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
+
+
+def convert_text_enc_state_dict_v20(text_enc_dict):
+    new_state_dict = {}
+    capture_qkv_weight = {}
+    capture_qkv_bias = {}
+    for k, v in text_enc_dict.items():
+        if (
+            k.endswith(".self_attn.q_proj.weight")
+            or k.endswith(".self_attn.k_proj.weight")
+            or k.endswith(".self_attn.v_proj.weight")
+        ):
+            k_pre = k[: -len(".q_proj.weight")]
+            k_code = k[-len("q_proj.weight")]
+            if k_pre not in capture_qkv_weight:
+                capture_qkv_weight[k_pre] = [None, None, None]
+            capture_qkv_weight[k_pre][code2idx[k_code]] = v
+            continue
+
+        if (
+            k.endswith(".self_attn.q_proj.bias")
+            or k.endswith(".self_attn.k_proj.bias")
+            or k.endswith(".self_attn.v_proj.bias")
+        ):
+            k_pre = k[: -len(".q_proj.bias")]
+            k_code = k[-len("q_proj.bias")]
+            if k_pre not in capture_qkv_bias:
+                capture_qkv_bias[k_pre] = [None, None, None]
+            capture_qkv_bias[k_pre][code2idx[k_code]] = v
+            continue
+
+        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+        new_state_dict[relabelled_key] = v
+
+    for k_pre, tensors in capture_qkv_weight.items():
+        if None in tensors:
+            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+    for k_pre, tensors in capture_qkv_bias.items():
+        if None in tensors:
+            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+    return new_state_dict
 
 
 def convert_text_enc_state_dict(text_enc_dict):
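
Note: the new convert_text_enc_state_dict_v20 gathers each layer's separate q/k/v projections and concatenates them into the fused in_proj_weight / in_proj_bias that the original SD v2 (OpenCLIP) checkpoint layout expects. A toy illustration of the torch.cat step (the 1024 width is only an example):

import torch

q = torch.randn(1024, 1024)   # hypothetical q_proj weight
k = torch.randn(1024, 1024)   # hypothetical k_proj weight
v = torch.randn(1024, 1024)   # hypothetical v_proj weight
in_proj_weight = torch.cat([q, k, v])   # fused in the order q, k, v (see code2idx)
assert in_proj_weight.shape == (3072, 1024)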
@@ -196,45 +282,56 @@ def convert_text_enc_state_dict(text_enc_dict):
 
 
 def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
-    from safetensors.torch import load_file, save_file
-    input_safetensors = False
-    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
-    if not osp.exists(unet_path):
-        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
-        input_safetensors = True
-    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
-    if not osp.exists(vae_path):
-        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
-        input_safetensors = True
-    text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
-    if not osp.exists(text_enc_path):
-        text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
-        input_safetensors = True
+    # Path for safetensors
+    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
+    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
+    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
+
+    # Load models from safetensors if it exists, if it doesn't pytorch
+    if osp.exists(unet_path):
+        unet_state_dict = load_file(unet_path, device="cpu")
+    else:
+        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
+        unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+    if osp.exists(vae_path):
+        vae_state_dict = load_file(vae_path, device="cpu")
+    else:
+        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
+        vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+    if osp.exists(text_enc_path):
+        text_enc_dict = load_file(text_enc_path, device="cpu")
+    else:
+        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
+        text_enc_dict = torch.load(text_enc_path, map_location="cpu")
 
     # Convert the UNet model
-    unet_state_dict = torch.load(unet_path, map_location='cpu') if not input_safetensors else load_file(unet_path, device='cpu')
     unet_state_dict = convert_unet_state_dict(unet_state_dict)
     unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
 
     # Convert the VAE model
-    vae_state_dict = torch.load(vae_path, map_location='cpu') if not input_safetensors else load_file(vae_path, device='cpu')
     vae_state_dict = convert_vae_state_dict(vae_state_dict)
     vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
 
-    # Convert the text encoder model
-    text_enc_dict = torch.load(text_enc_path, map_location='cpu') if not input_safetensors else load_file(text_enc_path, device='cpu')
-    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
-    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
+    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
+
+    if is_v20_model:
+        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
+        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
+        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
+        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+    else:
+        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
 
     # Put together new checkpoint
     state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
     if half:
-        state_dict = {k:v.half() for k,v in state_dict.items()}
-    if input_safetensors:
-        save_file(state_dict, checkpoint_path)
-    else:
-        state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, checkpoint_path)
+        state_dict = {k: v.half() for k, v in state_dict.items()}
+
+    save_file(state_dict, checkpoint_path)
 
 
 def download_repo(repo_id, dir_path):
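
Note: the loading logic added above repeats one pattern three times: prefer the .safetensors file and fall back to the legacy PyTorch .bin. Factored out as a sketch (load_state_dict is a hypothetical helper, not code from this commit):

import os.path as osp
import torch
from safetensors.torch import load_file

def load_state_dict(safetensors_path, bin_path):
    # Prefer the safetensors file; otherwise fall back to the pickle-based .bin
    if osp.exists(safetensors_path):
        return load_file(safetensors_path, device="cpu")
    return torch.load(bin_path, map_location="cpu")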
@@ -258,7 +355,7 @@ if __name__ == "__main__":
258
  parser = argparse.ArgumentParser()
259
 
260
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
261
- parser.add_argument("--half", action="store_true", default=True, help="Save weights in half precision.")
262
 
263
  args = parser.parse_args()
264
  assert args.repo_id is not None, "Must provide a Repo ID!"
 
1
  # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
  # *Only* converts the UNet, VAE, and Text Encoder.
3
  # Does not convert optimizer state or any other thing.
 
4
 
5
  import argparse
6
  import os.path as osp
7
+ import re
8
 
9
  import torch
10
+ from safetensors.torch import load_file, save_file
11
 
12
 
13
  # =================#
 
159
  ("proj_out.", "proj_attn."),
160
  ]
161
 
162
+ # This is probably not the most ideal solution, but it does work.
163
+ vae_extra_conversion_map = [
164
+ ("to_q", "q"),
165
+ ("to_k", "k"),
166
+ ("to_v", "v"),
167
+ ("to_out.0", "proj_out"),
168
+ ]
169
+
170
 
171
  def reshape_weight_for_sd(w):
172
  # convert HF linear weights to SD conv2d weights
173
+ if not w.ndim == 1:
174
+ return w.reshape(*w.shape, 1, 1)
175
+ else:
176
+ return w
177
 
178
 
179
  def convert_vae_state_dict(vae_state_dict):
 
189
  mapping[k] = v
190
  new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
191
  weights_to_convert = ["q", "k", "v", "proj_out"]
192
+ keys_to_rename = {}
193
  for k, v in new_state_dict.items():
194
  for weight_name in weights_to_convert:
195
  if f"mid.attn_1.{weight_name}.weight" in k:
196
  print(f"Reshaping {k} for SD format")
197
  new_state_dict[k] = reshape_weight_for_sd(v)
198
+ for weight_name, real_weight_name in vae_extra_conversion_map:
199
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
200
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
201
+ for k, v in keys_to_rename.items():
202
+ if k in new_state_dict:
203
+ print(f"Renaming {k} to {v}")
204
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
205
+ del new_state_dict[k]
206
  return new_state_dict
207
 
208
 
209
  # =========================#
210
  # Text Encoder Conversion #
211
  # =========================#
212
+
213
+
214
+ textenc_conversion_lst = [
215
+ # (stable-diffusion, HF Diffusers)
216
+ ("resblocks.", "text_model.encoder.layers."),
217
+ ("ln_1", "layer_norm1"),
218
+ ("ln_2", "layer_norm2"),
219
+ (".c_fc.", ".fc1."),
220
+ (".c_proj.", ".fc2."),
221
+ (".attn", ".self_attn"),
222
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
223
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
224
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
225
+ ]
226
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
227
+ textenc_pattern = re.compile("|".join(protected.keys()))
228
+
229
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
230
+ code2idx = {"q": 0, "k": 1, "v": 2}
231
+
232
+
233
+ def convert_text_enc_state_dict_v20(text_enc_dict):
234
+ new_state_dict = {}
235
+ capture_qkv_weight = {}
236
+ capture_qkv_bias = {}
237
+ for k, v in text_enc_dict.items():
238
+ if (
239
+ k.endswith(".self_attn.q_proj.weight")
240
+ or k.endswith(".self_attn.k_proj.weight")
241
+ or k.endswith(".self_attn.v_proj.weight")
242
+ ):
243
+ k_pre = k[: -len(".q_proj.weight")]
244
+ k_code = k[-len("q_proj.weight")]
245
+ if k_pre not in capture_qkv_weight:
246
+ capture_qkv_weight[k_pre] = [None, None, None]
247
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
248
+ continue
249
+
250
+ if (
251
+ k.endswith(".self_attn.q_proj.bias")
252
+ or k.endswith(".self_attn.k_proj.bias")
253
+ or k.endswith(".self_attn.v_proj.bias")
254
+ ):
255
+ k_pre = k[: -len(".q_proj.bias")]
256
+ k_code = k[-len("q_proj.bias")]
257
+ if k_pre not in capture_qkv_bias:
258
+ capture_qkv_bias[k_pre] = [None, None, None]
259
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
260
+ continue
261
+
262
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
263
+ new_state_dict[relabelled_key] = v
264
+
265
+ for k_pre, tensors in capture_qkv_weight.items():
266
+ if None in tensors:
267
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
268
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
269
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
270
+
271
+ for k_pre, tensors in capture_qkv_bias.items():
272
+ if None in tensors:
273
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
274
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
275
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
276
+
277
+ return new_state_dict
278
 
279
 
280
  def convert_text_enc_state_dict(text_enc_dict):
 
282
 
283
 
284
  def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
285
+ # Path for safetensors
286
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
287
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
288
+ text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
289
+
290
+ # Load models from safetensors if it exists, if it doesn't pytorch
291
+ if osp.exists(unet_path):
292
+ unet_state_dict = load_file(unet_path, device="cpu")
293
+ else:
294
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
295
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
296
+
297
+ if osp.exists(vae_path):
298
+ vae_state_dict = load_file(vae_path, device="cpu")
299
+ else:
300
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
301
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
302
+
303
+ if osp.exists(text_enc_path):
304
+ text_enc_dict = load_file(text_enc_path, device="cpu")
305
+ else:
306
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
307
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
308
 
309
  # Convert the UNet model
 
310
  unet_state_dict = convert_unet_state_dict(unet_state_dict)
311
  unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
312
 
313
  # Convert the VAE model
 
314
  vae_state_dict = convert_vae_state_dict(vae_state_dict)
315
  vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
316
 
317
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
318
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
319
+
320
+ if is_v20_model:
321
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
322
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
323
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
324
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
325
+ else:
326
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
327
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
328
 
329
  # Put together new checkpoint
330
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
331
  if half:
332
+ state_dict = {k: v.half() for k, v in state_dict.items()}
333
+
334
+ save_file(state_dict, checkpoint_path)
 
 
 
335
 
336
 
337
  def download_repo(repo_id, dir_path):
 
355
  parser = argparse.ArgumentParser()
356
 
357
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
358
+ parser.add_argument("--half", default=True, help="Save weights in half precision.")
359
 
360
  args = parser.parse_args()
361
  assert args.repo_id is not None, "Must provide a Repo ID!"
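
Note: after this change the script always writes a .safetensors checkpoint via save_file, whether the source repo shipped .bin or .safetensors weights. A usage sketch (repo id and paths are placeholders):

# python convert_repo_to_safetensors_sd.py --repo_id some-author/some-diffusers-model
# or, from Python, assuming the pipeline was already downloaded to ./model:
from convert_repo_to_safetensors_sd import convert_diffusers_to_safetensors

convert_diffusers_to_safetensors("./model", "./model.safetensors", half=True)

One behavioral detail: with action="store_true" removed, passing --half on the command line now requires an explicit value (and any non-empty string parses as truthy); when the flag is omitted, the default remains True.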
convert_repo_to_safetensors_sd_gr.py CHANGED
@@ -1,14 +1,16 @@
 # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
 # *Only* converts the UNet, VAE, and Text Encoder.
 # Does not convert optimizer state or any other thing.
-# Written by jachiam
 
 import argparse
 import os.path as osp
+import re
 
 import torch
+from safetensors.torch import load_file, save_file
 import gradio as gr
 
+
 # =================#
 # UNet Conversion #
 # =================#
@@ -158,10 +160,21 @@ vae_conversion_map_attn = [
     ("proj_out.", "proj_attn."),
 ]
 
+# This is probably not the most ideal solution, but it does work.
+vae_extra_conversion_map = [
+    ("to_q", "q"),
+    ("to_k", "k"),
+    ("to_v", "v"),
+    ("to_out.0", "proj_out"),
+]
+
 
 def reshape_weight_for_sd(w):
     # convert HF linear weights to SD conv2d weights
-    return w.reshape(*w.shape, 1, 1)
+    if not w.ndim == 1:
+        return w.reshape(*w.shape, 1, 1)
+    else:
+        return w
 
 
 def convert_vae_state_dict(vae_state_dict):
@@ -177,18 +190,92 @@ def convert_vae_state_dict(vae_state_dict):
         mapping[k] = v
     new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
     weights_to_convert = ["q", "k", "v", "proj_out"]
+    keys_to_rename = {}
     for k, v in new_state_dict.items():
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
                 print(f"Reshaping {k} for SD format")
                 new_state_dict[k] = reshape_weight_for_sd(v)
+        for weight_name, real_weight_name in vae_extra_conversion_map:
+            if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
+                keys_to_rename[k] = k.replace(weight_name, real_weight_name)
+    for k, v in keys_to_rename.items():
+        if k in new_state_dict:
+            print(f"Renaming {k} to {v}")
+            new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
+            del new_state_dict[k]
     return new_state_dict
 
 
 # =========================#
 # Text Encoder Conversion #
 # =========================#
-# pretty much a no-op
+
+
+textenc_conversion_lst = [
+    # (stable-diffusion, HF Diffusers)
+    ("resblocks.", "text_model.encoder.layers."),
+    ("ln_1", "layer_norm1"),
+    ("ln_2", "layer_norm2"),
+    (".c_fc.", ".fc1."),
+    (".c_proj.", ".fc2."),
+    (".attn", ".self_attn"),
+    ("ln_final.", "transformer.text_model.final_layer_norm."),
+    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
+
+
+def convert_text_enc_state_dict_v20(text_enc_dict):
+    new_state_dict = {}
+    capture_qkv_weight = {}
+    capture_qkv_bias = {}
+    for k, v in text_enc_dict.items():
+        if (
+            k.endswith(".self_attn.q_proj.weight")
+            or k.endswith(".self_attn.k_proj.weight")
+            or k.endswith(".self_attn.v_proj.weight")
+        ):
+            k_pre = k[: -len(".q_proj.weight")]
+            k_code = k[-len("q_proj.weight")]
+            if k_pre not in capture_qkv_weight:
+                capture_qkv_weight[k_pre] = [None, None, None]
+            capture_qkv_weight[k_pre][code2idx[k_code]] = v
+            continue
+
+        if (
+            k.endswith(".self_attn.q_proj.bias")
+            or k.endswith(".self_attn.k_proj.bias")
+            or k.endswith(".self_attn.v_proj.bias")
+        ):
+            k_pre = k[: -len(".q_proj.bias")]
+            k_code = k[-len("q_proj.bias")]
+            if k_pre not in capture_qkv_bias:
+                capture_qkv_bias[k_pre] = [None, None, None]
+            capture_qkv_bias[k_pre][code2idx[k_code]] = v
+            continue
+
+        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+        new_state_dict[relabelled_key] = v
+
+    for k_pre, tensors in capture_qkv_weight.items():
+        if None in tensors:
+            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+    for k_pre, tensors in capture_qkv_bias.items():
+        if None in tensors:
+            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+    return new_state_dict
 
 
 def convert_text_enc_state_dict(text_enc_dict):
@@ -197,45 +284,56 @@ def convert_text_enc_state_dict(text_enc_dict):
 
 def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, progress=gr.Progress(track_tqdm=True)):
     progress(0, desc="Start converting...")
-    from safetensors.torch import load_file, save_file
-    input_safetensors = False
-    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
-    if not osp.exists(unet_path):
-        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
-        input_safetensors = True
-    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
-    if not osp.exists(vae_path):
-        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
-        input_safetensors = True
-    text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
-    if not osp.exists(text_enc_path):
-        text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
-        input_safetensors = True
+    # Path for safetensors
+    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
+    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
+    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
+
+    # Load models from safetensors if it exists, if it doesn't pytorch
+    if osp.exists(unet_path):
+        unet_state_dict = load_file(unet_path, device="cpu")
+    else:
+        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
+        unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+    if osp.exists(vae_path):
+        vae_state_dict = load_file(vae_path, device="cpu")
+    else:
+        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
+        vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+    if osp.exists(text_enc_path):
+        text_enc_dict = load_file(text_enc_path, device="cpu")
+    else:
+        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
+        text_enc_dict = torch.load(text_enc_path, map_location="cpu")
 
     # Convert the UNet model
-    unet_state_dict = torch.load(unet_path, map_location='cpu') if not input_safetensors else load_file(unet_path)
     unet_state_dict = convert_unet_state_dict(unet_state_dict)
     unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
 
     # Convert the VAE model
-    vae_state_dict = torch.load(vae_path, map_location='cpu') if not input_safetensors else load_file(vae_path)
     vae_state_dict = convert_vae_state_dict(vae_state_dict)
     vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
 
-    # Convert the text encoder model
-    text_enc_dict = torch.load(text_enc_path, map_location='cpu') if not input_safetensors else load_file(text_enc_path)
-    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
-    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
+    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
+
+    if is_v20_model:
+        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
+        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
+        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
+        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+    else:
+        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
 
     # Put together new checkpoint
     state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
     if half:
-        state_dict = {k:v.half() for k,v in state_dict.items()}
-    if input_safetensors:
-        save_file(state_dict, checkpoint_path)
-    else:
-        state_dict = {"state_dict": state_dict}
-        torch.save(state_dict, checkpoint_path)
+        state_dict = {k: v.half() for k, v in state_dict.items()}
+
+    save_file(state_dict, checkpoint_path)
 
     progress(1, desc="Converted.")
 
@@ -295,7 +393,7 @@ if __name__ == "__main__":
295
  parser = argparse.ArgumentParser()
296
 
297
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
298
- parser.add_argument("--half", action="store_true", default=True, help="Save weights in half precision.")
299
 
300
  args = parser.parse_args()
301
  assert args.repo_id is not None, "Must provide a Repo ID!"
 
1
  # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
  # *Only* converts the UNet, VAE, and Text Encoder.
3
  # Does not convert optimizer state or any other thing.
 
4
 
5
  import argparse
6
  import os.path as osp
7
+ import re
8
 
9
  import torch
10
+ from safetensors.torch import load_file, save_file
11
  import gradio as gr
12
 
13
+
14
  # =================#
15
  # UNet Conversion #
16
  # =================#
 
160
  ("proj_out.", "proj_attn."),
161
  ]
162
 
163
+ # This is probably not the most ideal solution, but it does work.
164
+ vae_extra_conversion_map = [
165
+ ("to_q", "q"),
166
+ ("to_k", "k"),
167
+ ("to_v", "v"),
168
+ ("to_out.0", "proj_out"),
169
+ ]
170
+
171
 
172
  def reshape_weight_for_sd(w):
173
  # convert HF linear weights to SD conv2d weights
174
+ if not w.ndim == 1:
175
+ return w.reshape(*w.shape, 1, 1)
176
+ else:
177
+ return w
178
 
179
 
180
  def convert_vae_state_dict(vae_state_dict):
 
190
  mapping[k] = v
191
  new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
192
  weights_to_convert = ["q", "k", "v", "proj_out"]
193
+ keys_to_rename = {}
194
  for k, v in new_state_dict.items():
195
  for weight_name in weights_to_convert:
196
  if f"mid.attn_1.{weight_name}.weight" in k:
197
  print(f"Reshaping {k} for SD format")
198
  new_state_dict[k] = reshape_weight_for_sd(v)
199
+ for weight_name, real_weight_name in vae_extra_conversion_map:
200
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
201
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
202
+ for k, v in keys_to_rename.items():
203
+ if k in new_state_dict:
204
+ print(f"Renaming {k} to {v}")
205
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
206
+ del new_state_dict[k]
207
  return new_state_dict
208
 
209
 
210
  # =========================#
211
  # Text Encoder Conversion #
212
  # =========================#
213
+
214
+
215
+ textenc_conversion_lst = [
216
+ # (stable-diffusion, HF Diffusers)
217
+ ("resblocks.", "text_model.encoder.layers."),
218
+ ("ln_1", "layer_norm1"),
219
+ ("ln_2", "layer_norm2"),
220
+ (".c_fc.", ".fc1."),
221
+ (".c_proj.", ".fc2."),
222
+ (".attn", ".self_attn"),
223
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
224
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
225
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
226
+ ]
227
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
228
+ textenc_pattern = re.compile("|".join(protected.keys()))
229
+
230
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
231
+ code2idx = {"q": 0, "k": 1, "v": 2}
232
+
233
+
234
+ def convert_text_enc_state_dict_v20(text_enc_dict):
235
+ new_state_dict = {}
236
+ capture_qkv_weight = {}
237
+ capture_qkv_bias = {}
238
+ for k, v in text_enc_dict.items():
239
+ if (
240
+ k.endswith(".self_attn.q_proj.weight")
241
+ or k.endswith(".self_attn.k_proj.weight")
242
+ or k.endswith(".self_attn.v_proj.weight")
243
+ ):
244
+ k_pre = k[: -len(".q_proj.weight")]
245
+ k_code = k[-len("q_proj.weight")]
246
+ if k_pre not in capture_qkv_weight:
247
+ capture_qkv_weight[k_pre] = [None, None, None]
248
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
249
+ continue
250
+
251
+ if (
252
+ k.endswith(".self_attn.q_proj.bias")
253
+ or k.endswith(".self_attn.k_proj.bias")
254
+ or k.endswith(".self_attn.v_proj.bias")
255
+ ):
256
+ k_pre = k[: -len(".q_proj.bias")]
257
+ k_code = k[-len("q_proj.bias")]
258
+ if k_pre not in capture_qkv_bias:
259
+ capture_qkv_bias[k_pre] = [None, None, None]
260
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
261
+ continue
262
+
263
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
264
+ new_state_dict[relabelled_key] = v
265
+
266
+ for k_pre, tensors in capture_qkv_weight.items():
267
+ if None in tensors:
268
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
269
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
270
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
271
+
272
+ for k_pre, tensors in capture_qkv_bias.items():
273
+ if None in tensors:
274
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
275
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
276
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
277
+
278
+ return new_state_dict
279
 
280
 
281
  def convert_text_enc_state_dict(text_enc_dict):
 
284
 
285
  def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, progress=gr.Progress(track_tqdm=True)):
286
  progress(0, desc="Start converting...")
287
+ # Path for safetensors
288
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
289
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
290
+ text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
291
+
292
+ # Load models from safetensors if it exists, if it doesn't pytorch
293
+ if osp.exists(unet_path):
294
+ unet_state_dict = load_file(unet_path, device="cpu")
295
+ else:
296
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
297
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
298
+
299
+ if osp.exists(vae_path):
300
+ vae_state_dict = load_file(vae_path, device="cpu")
301
+ else:
302
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
303
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
304
+
305
+ if osp.exists(text_enc_path):
306
+ text_enc_dict = load_file(text_enc_path, device="cpu")
307
+ else:
308
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
309
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
310
 
311
  # Convert the UNet model
 
312
  unet_state_dict = convert_unet_state_dict(unet_state_dict)
313
  unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
314
 
315
  # Convert the VAE model
 
316
  vae_state_dict = convert_vae_state_dict(vae_state_dict)
317
  vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
318
 
319
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
320
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
321
+
322
+ if is_v20_model:
323
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
324
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
325
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
326
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
327
+ else:
328
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
329
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
330
 
331
  # Put together new checkpoint
332
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
333
  if half:
334
+ state_dict = {k: v.half() for k, v in state_dict.items()}
335
+
336
+ save_file(state_dict, checkpoint_path)
 
 
 
337
 
338
  progress(1, desc="Converted.")
339
 
 
393
  parser = argparse.ArgumentParser()
394
 
395
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
396
+ parser.add_argument("--half", default=True, help="Save weights in half precision.")
397
 
398
  args = parser.parse_args()
399
  assert args.repo_id is not None, "Must provide a Repo ID!"
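
Note: the _gr variant is the same converter with a gr.Progress(track_tqdm=True) parameter and progress() calls around the conversion. A minimal sketch (an assumption, not part of this commit) of wiring it into a small Gradio UI:

import gradio as gr

from convert_repo_to_safetensors_sd_gr import convert_diffusers_to_safetensors

def run(model_dir: str, out_path: str, progress=gr.Progress(track_tqdm=True)):
    # Gradio injects the progress tracker when run() is invoked from the UI
    convert_diffusers_to_safetensors(model_dir, out_path, half=True, progress=progress)
    return out_path

demo = gr.Interface(fn=run, inputs=["text", "text"], outputs="text")
demo.launch()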