| """Helper script to convert diffusion checkpoints to format used by image generator.""" | |
| import os | |
| from absl import app | |
| from absl import flags | |
| import requests | |
| import torch as th | |
# Command-line flags: --ckpt_path is mandatory, --output_path defaults to "bins".
_CKPT_PATH = flags.DEFINE_string(
    name="ckpt_path",
    default=None,
    help="Path to checkpoint file",
    required=True,
)
_OUTPUT_PATH = flags.DEFINE_string(
    name="output_path",
    default="bins",
    help="Output folder path",
    required=False,
)

# Location of the CLIP BPE vocabulary file that is copied next to the
# converted weights.
VOCAB_URL = (
    "https://openaipublic.blob.core.windows.net/clip/"
    "bpe_simple_vocab_16e6.txt"
)
def _download_vocab(vocab_dest):
    """Download the file at VOCAB_URL to ``vocab_dest`` atomically.

    The response body is streamed into ``<vocab_dest>.part`` and renamed into
    place only once it has been written completely. An interrupted or failed
    download therefore never leaves a truncated file at ``vocab_dest``. This
    matters because ``run`` skips the download whenever that path exists.

    Raises:
      requests.HTTPError: if the server responds with an error status.
      requests.Timeout: if connecting or reading stalls for 60 seconds.
    """
    tmp_dest = vocab_dest + ".part"
    try:
        with requests.get(VOCAB_URL, stream=True, timeout=60) as response:
            # Without this check an HTTP error page would be saved as the vocab.
            response.raise_for_status()
            with open(tmp_dest, "wb") as vocab_file:
                for chunk in response.iter_content(chunk_size=8192):
                    vocab_file.write(chunk)
        os.replace(tmp_dest, vocab_dest)
    finally:
        # The temp file only survives to here if something failed before the
        # rename; remove it so no partial download is left behind.
        if os.path.exists(tmp_dest):
            os.remove(tmp_dest)


def _export_tensors(state_dict, output_path):
    """Write each tensor in ``state_dict`` to ``<output_path>/<key>.bin``.

    Each file holds the tensor's elements as a flat, headerless float16 array
    (``ndarray.tofile`` writes C order in native byte order). Keys containing
    ``first_stage_model.encoder`` and values that are not tensors are skipped.
    """
    for name, value in state_dict.items():
        if "first_stage_model.encoder" in name:
            continue
        if not isinstance(value, th.Tensor):
            continue
        # Tensor.numpy() refuses tensors that require grad or live off-CPU.
        tensor = value.detach().cpu()
        # NumPy has no bfloat16, so Tensor.numpy() raises for it. Widening to
        # float32 is lossless and leaves the float16 rounding to NumPy, the same
        # as for every other dtype.
        if tensor.dtype == th.bfloat16:
            tensor = tensor.float()
        output_bin_file = os.path.join(output_path, f"{name}.bin")
        tensor.numpy().astype("float16").tofile(output_bin_file)


def run(ckpt_path, output_path):
    """Converts the checkpoint and saves the result.

    Creates ``output_path`` if needed. Downloads the file at VOCAB_URL into it
    unless a file with that basename is already present. Then writes every
    tensor of ``ckpt["state_dict"]`` as a raw float16 ``<key>.bin`` file,
    except keys containing ``first_stage_model.encoder``.

    Args:
      ckpt_path: Source checkpoint path
      output_path: Result folder directory

    Raises:
      KeyError: if the checkpoint has no ``"state_dict"`` entry.
      requests.HTTPError: if the vocabulary download fails.
    """
    os.makedirs(output_path, exist_ok=True)
    # SECURITY: th.load unpickles arbitrary Python objects; only convert
    # checkpoints from trusted sources.
    # NOTE(review): recent torch releases default to weights_only=True, which
    # may reject checkpoints that store non-tensor objects — confirm against
    # the torch version in use.
    ckpt = th.load(ckpt_path, map_location="cpu")
    vocab_dest = os.path.join(output_path, os.path.basename(VOCAB_URL))
    if not os.path.exists(vocab_dest):
        _download_vocab(vocab_dest)
    _export_tensors(ckpt["state_dict"], output_path)
def main(_) -> None:
    """absl entry point: run the conversion using the parsed flag values.

    The positional argument (absl's leftover argv) is ignored.
    """
    run(ckpt_path=_CKPT_PATH.value, output_path=_OUTPUT_PATH.value)
if __name__ == "__main__":
    # app.run parses the absl command-line flags, then invokes main(argv).
    app.run(main=main)