Upload convert_repo_to_safetensors_sd_gr.py
Browse files
convert_repo_to_safetensors_sd_gr.py
CHANGED
@@ -339,6 +339,7 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16",
|
|
339 |
if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
|
340 |
elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
|
341 |
elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
|
|
|
342 |
|
343 |
save_file(state_dict, checkpoint_path)
|
344 |
|
@@ -417,7 +418,7 @@ if __name__ == "__main__":
|
|
417 |
parser = argparse.ArgumentParser()
|
418 |
|
419 |
parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
|
420 |
-
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
|
421 |
|
422 |
args = parser.parse_args()
|
423 |
assert args.repo_id is not None, "Must provide a Repo ID!"
|
|
|
339 |
if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
|
340 |
elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
|
341 |
elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
|
342 |
+
elif dtype == "fp8": state_dict = {k: v.to(torch.float8_e4m3fn) for k, v in state_dict.items()}
|
343 |
|
344 |
save_file(state_dict, checkpoint_path)
|
345 |
|
|
|
418 |
parser = argparse.ArgumentParser()
|
419 |
|
420 |
parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
|
421 |
+
parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "fp8", "default"], help='Output data type. (Default: "fp16")')
|
422 |
|
423 |
args = parser.parse_args()
|
424 |
assert args.repo_id is not None, "Must provide a Repo ID!"
|