"""Convert PyTorch ``.ckpt`` checkpoints to the safetensors format."""

import argparse
from pathlib import Path

import torch
from safetensors.torch import save_file


def convert(path: Path, half: bool = False, no_ema: bool = False) -> None:
    """Convert one checkpoint file to a ``.safetensors`` file beside it.

    The checkpoint is loaded on the CPU. If the loaded object has a
    ``"state_dict"`` entry, that entry is used as the set of weights.
    Entries whose value is not a ``torch.Tensor`` are dropped.

    The output is written to the same directory as ``path`` and is named
    ``<stem>[-pruned][-fp16].safetensors``, where ``-pruned`` is added when
    ``no_ema`` is set and ``-fp16`` when ``half`` is set.

    Args:
        path: Checkpoint file to read.
        half: Cast floating-point tensors to ``float16``. Integer and bool
            tensors are left untouched.
        no_ema: Drop every entry whose key contains the substring ``"ema"``.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only run this on checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    tensors = {}
    for key, value in checkpoint.items():
        if not isinstance(value, torch.Tensor):
            continue
        # NOTE: plain substring match, so any key containing "ema" is dropped.
        if no_ema and "ema" in key:
            continue
        # Only floating-point weights are down-cast; casting integer/bool
        # tensors (e.g. index buffers) to float16 would corrupt them.
        if half and value.is_floating_point():
            value = value.half()
        # save_file rejects non-contiguous tensors; contiguous() is a no-op
        # for tensors that are already contiguous.
        tensors[key] = value.contiguous()

    output_name = path.stem
    if no_ema:
        output_name += "-pruned"
    if half:
        output_name += "-fp16"

    output_path = path.parent / f"{output_name}.safetensors"
    save_file(tensors, output_path.as_posix())


def main(path: str, half: bool = False, no_ema: bool = False) -> None:
    """Convert a single checkpoint, or every ``*.ckpt`` file in a directory.

    Args:
        path: Path to a checkpoint file or to a directory. A directory is
            searched non-recursively for ``*.ckpt`` files.
        half: Forwarded to :func:`convert`.
        no_ema: Forwarded to :func:`convert`.

    Raises:
        ValueError: If ``path`` does not exist.
    """
    path_ = Path(path).resolve()
    if not path_.exists():
        raise ValueError(f"Invalid path: {path}")

    if path_.is_file():
        to_convert = [path_]
    else:
        # Sorted so that the conversion order is deterministic.
        to_convert = sorted(path_.glob("*.ckpt"))
        if not to_convert:
            print(f"No .ckpt files found in {path_}")

    for file in to_convert:
        print(f"Converting... {file}")
        convert(file, half, no_ema)


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments: ``path``, ``--half``, ``--no-ema``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("path", type=str, help="Path to checkpoint file or directory.")
    parser.add_argument(
        "--half", action="store_true", help="Convert to half precision."
    )
    parser.add_argument("--no-ema", action="store_true", help="Ignore EMA weights.")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args.path, args.half, args.no_ema)