import argparse

import safetensors.torch

from diffusers import AutoencoderTiny


"""
Example - From the diffusers root directory:

Download the weights:
```sh
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
```

Convert the model:
```sh
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
    --encoder_ckpt_path taesd_encoder.safetensors \
    --decoder_ckpt_path taesd_decoder.safetensors \
    --dump_path taesd-diffusers
```
"""
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = {}
    # Remap the encoder keys: prefix each original key with `encoder.layers.` so it
    # matches the module layout of diffusers' `AutoencoderTiny` encoder.
    for k in encoder_state_dict:
        new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]})
    # Remap the decoder keys: shift the leading layer index down by one so the original
    # checkpoint's numbering (which starts at 1) lines up with diffusers' zero-based
    # `decoder.layers`.
    for k in decoder_state_dict:
        layer_id = int(k.split(".")[0]) - 1
        new_k = str(layer_id) + "." + ".".join(k.split(".")[1:])
        new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]})
    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    tiny_autoencoder.load_state_dict(new_state_dict)

    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
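
    # Optional sanity check: reload the serialized model to confirm the dump
    # round-trips through `from_pretrained`.
    print("Reloading the converted model as a sanity check...")
    AutoencoderTiny.from_pretrained(args.dump_path)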