# sd-1-5-mse-vae / swap_vae.py
# Uploaded by nousr with huggingface_hub (commit b072d5b).
import torch
import click
def overwrite_first_stage(model_state_dict, vae_state_dict):
    """
    Overwrite a model's first-stage (autoencoder) weights with a new VAE's.

    From the new VAE's repo:
        To keep compatibility with existing models,
        only the decoder part was finetuned;
        the checkpoints can be used as a drop-in replacement
        for the existing autoencoder.

    Per that note only the decoder strictly needs replacing, but this copies
    every first-stage weight whose key contains "encoder" or "decoder".
    If the encoder really was not finetuned, its weights in the new VAE
    should equal the originals, so copying them is expected to be harmless
    (NOTE(review): confirm for the specific VAE checkpoint being used).
    First-stage weights whose keys contain neither word are left untouched.

    Args:
        model_state_dict: Full model state dict. Autoencoder keys are
            recognised by containing ``"first_stage_model."``. It is
            modified in place.
        vae_state_dict: State dict of the replacement VAE. Its keys are
            expected to be the model's first-stage keys with everything up to
            and including the first ``"first_stage_model."`` removed.

    Returns:
        The same ``model_state_dict`` object, after modification.

    Weights missing from ``vae_state_dict`` are reported on stdout and the
    model's original weight is kept.
    """
    target = "first_stage_model."
    for key in model_state_dict:
        if target not in key:
            continue
        if "decoder" not in key and "encoder" not in key:
            continue
        # Split on the first occurrence only, so a key that happens to repeat
        # the target substring keeps its full suffix instead of a middle piece.
        matching_name = key.split(target, 1)[1]
        # Double check this weight exists in the new VAE before overwriting.
        if matching_name in vae_state_dict:
            # Reassigning an existing key does not change the dict's size,
            # so doing it while iterating is safe.
            model_state_dict[key] = vae_state_dict[matching_name]
        else:
            print(f"{key} Does not exist in the new VAE weights!")
    return model_state_dict
@click.command()
@click.option(
    "--base-model",
    type=click.Path(exists=True, dir_okay=False),
    default="sd-v1-5.ckpt",
    show_default=True,
    help="Checkpoint whose first-stage (VAE) weights will be replaced.",
)
@click.option(
    "--vae",
    type=click.Path(exists=True, dir_okay=False),
    default="new_vae.ckpt",
    show_default=True,
    help="Checkpoint holding the replacement VAE weights.",
)
@click.option(
    "--output-name",
    type=click.Path(dir_okay=False),
    default="sd-v1-5-new-vae.ckpt",
    show_default=True,
    help="Where to write the merged checkpoint.",
)
def main(base_model, vae, output_name):
    """Swap the VAE weights of --base-model for those in --vae and save the result."""
    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints from sources you trust.
    # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only
    # machines; the swap below is a pure dict operation and needs no device.
    model = torch.load(base_model, map_location="cpu")
    new_vae = torch.load(vae, map_location="cpu")
    # Accept either a checkpoint that wraps its weights under "state_dict"
    # (the only form the original script handled) or a bare state dict.
    vae_state_dict = new_vae.get("state_dict", new_vae)
    model["state_dict"] = overwrite_first_stage(model["state_dict"], vae_state_dict)
    print(f"Saving to {output_name}")
    torch.save(model, output_name)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()