Upload 2 files
Browse files
convert_repo_to_safetensors_sd.py
CHANGED
@@ -227,15 +227,12 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
|
|
227 |
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
228 |
|
229 |
# Put together new checkpoint
|
|
|
|
|
|
|
230 |
if input_safetensors:
|
231 |
-
state_dict
|
232 |
-
if half:
|
233 |
-
state_dict = {k:v.half() for k,v in state_dict.items()}
|
234 |
-
save_file(state_dict, checkpoint_path, metadata={"format": "pt"})
|
235 |
else:
|
236 |
-
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
237 |
-
if half:
|
238 |
-
state_dict = {k:v.half() for k,v in state_dict.items()}
|
239 |
state_dict = {"state_dict": state_dict}
|
240 |
torch.save(state_dict, checkpoint_path)
|
241 |
|
|
|
227 |
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
228 |
|
229 |
# Put together new checkpoint
|
230 |
+
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
231 |
+
if half:
|
232 |
+
state_dict = {k:v.half() for k,v in state_dict.items()}
|
233 |
if input_safetensors:
|
234 |
+
save_file(state_dict, checkpoint_path)
|
|
|
|
|
|
|
235 |
else:
|
|
|
|
|
|
|
236 |
state_dict = {"state_dict": state_dict}
|
237 |
torch.save(state_dict, checkpoint_path)
|
238 |
|
convert_repo_to_safetensors_sd_gr.py
CHANGED
@@ -228,15 +228,12 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, p
|
|
228 |
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
229 |
|
230 |
# Put together new checkpoint
|
|
|
|
|
|
|
231 |
if input_safetensors:
|
232 |
-
state_dict
|
233 |
-
if half:
|
234 |
-
state_dict = {k:v.half() for k,v in state_dict.items()}
|
235 |
-
save_file(state_dict, checkpoint_path, metadata={"format": "pt"})
|
236 |
else:
|
237 |
-
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
238 |
-
if half:
|
239 |
-
state_dict = {k:v.half() for k,v in state_dict.items()}
|
240 |
state_dict = {"state_dict": state_dict}
|
241 |
torch.save(state_dict, checkpoint_path)
|
242 |
|
|
|
228 |
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
229 |
|
230 |
# Put together new checkpoint
|
231 |
+
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
232 |
+
if half:
|
233 |
+
state_dict = {k:v.half() for k,v in state_dict.items()}
|
234 |
if input_safetensors:
|
235 |
+
save_file(state_dict, checkpoint_path)
|
|
|
|
|
|
|
236 |
else:
|
|
|
|
|
|
|
237 |
state_dict = {"state_dict": state_dict}
|
238 |
torch.save(state_dict, checkpoint_path)
|
239 |
|