John6666 commited on
Commit
07b4e15
·
verified ·
1 Parent(s): ca0ed41

Upload 2 files

Browse files
convert_repo_to_safetensors_sd.py CHANGED
@@ -227,15 +227,12 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
227
  text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
228
 
229
  # Put together new checkpoint
 
 
 
230
  if input_safetensors:
231
- state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
232
- if half:
233
- state_dict = {k:v.half() for k,v in state_dict.items()}
234
- save_file(state_dict, checkpoint_path, metadata={"format": "pt"})
235
  else:
236
- state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
237
- if half:
238
- state_dict = {k:v.half() for k,v in state_dict.items()}
239
  state_dict = {"state_dict": state_dict}
240
  torch.save(state_dict, checkpoint_path)
241
 
 
227
  text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
228
 
229
  # Put together new checkpoint
230
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
231
+ if half:
232
+ state_dict = {k:v.half() for k,v in state_dict.items()}
233
  if input_safetensors:
234
+ save_file(state_dict, checkpoint_path)
 
 
 
235
  else:
 
 
 
236
  state_dict = {"state_dict": state_dict}
237
  torch.save(state_dict, checkpoint_path)
238
 
convert_repo_to_safetensors_sd_gr.py CHANGED
@@ -228,15 +228,12 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, p
228
  text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
229
 
230
  # Put together new checkpoint
 
 
 
231
  if input_safetensors:
232
- state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
233
- if half:
234
- state_dict = {k:v.half() for k,v in state_dict.items()}
235
- save_file(state_dict, checkpoint_path, metadata={"format": "pt"})
236
  else:
237
- state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
238
- if half:
239
- state_dict = {k:v.half() for k,v in state_dict.items()}
240
  state_dict = {"state_dict": state_dict}
241
  torch.save(state_dict, checkpoint_path)
242
 
 
228
  text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
229
 
230
  # Put together new checkpoint
231
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
232
+ if half:
233
+ state_dict = {k:v.half() for k,v in state_dict.items()}
234
  if input_safetensors:
235
+ save_file(state_dict, checkpoint_path)
 
 
 
236
  else:
 
 
 
237
  state_dict = {"state_dict": state_dict}
238
  torch.save(state_dict, checkpoint_path)
239