John6666 commited on
Commit
97f4f4d
1 Parent(s): c983a5f

Upload convert_repo_to_safetensors_sd_gr.py

Browse files
convert_repo_to_safetensors_sd_gr.py CHANGED
@@ -339,6 +339,7 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16",
339
  if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
340
  elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
341
  elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
 
342
 
343
  save_file(state_dict, checkpoint_path)
344
 
@@ -417,7 +418,7 @@ if __name__ == "__main__":
417
  parser = argparse.ArgumentParser()
418
 
419
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
420
- parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
421
 
422
  args = parser.parse_args()
423
  assert args.repo_id is not None, "Must provide a Repo ID!"
 
339
  if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
340
  elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
341
  elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
342
+ elif dtype == "fp8": state_dict = {k: v.to(torch.float8_e4m3fn) for k, v in state_dict.items()}
343
 
344
  save_file(state_dict, checkpoint_path)
345
 
 
418
  parser = argparse.ArgumentParser()
419
 
420
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
421
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "fp8", "default"], help='Output data type. (Default: "fp16")')
422
 
423
  args = parser.parse_args()
424
  assert args.repo_id is not None, "Must provide a Repo ID!"