"""Swap the first-stage autoencoder (VAE) weights inside a model checkpoint.

Usage:
    python <this script> --base-model sd-v1-5.ckpt --vae new_vae.ckpt \
        --output-name sd-v1-5-new-vae.ckpt
"""
import torch
import click


def overwrite_first_stage(model_state_dict, vae_state_dict):
    """Overwrite the first-stage encoder/decoder weights of a model with new VAE weights.

    From the new VAE release: "To keep compatibility with existing models, only
    the decoder part was finetuned; the checkpoints can be used as a drop-in
    replacement for the existing autoencoder."  Even though only the decoder
    changed, this function copies every key under the ``first_stage_model.``
    prefix whose name contains ``encoder`` or ``decoder``.  Other first-stage
    keys (anything without those substrings) are left untouched.

    Args:
        model_state_dict: State dict of the full model.  It is modified in place.
        vae_state_dict: State dict of the replacement VAE.  Its keys are looked up
            by the model key with everything up to and including the first
            ``first_stage_model.`` removed.

    Returns:
        The same ``model_state_dict`` object, after the update.
    """
    target = "first_stage_model."
    # Reassigning values of existing keys does not resize the dict, so it is
    # safe to iterate over it directly while updating.
    for key in model_state_dict:
        if target not in key or ("decoder" not in key and "encoder" not in key):
            continue
        # maxsplit=1 keeps the full remainder even if the prefix repeats later on.
        matching_name = key.split(target, 1)[1]
        # Double check this weight exists in the new VAE before copying it.
        if matching_name in vae_state_dict:
            model_state_dict[key] = vae_state_dict[matching_name]
        else:
            print(f"{key} Does not exist in the new VAE weights!")
    return model_state_dict


@click.command()
@click.option("--base-model", type=str, default="sd-v1-5.ckpt")
@click.option("--vae", type=str, default="new_vae.ckpt")
@click.option("--output-name", type=str, default="sd-v1-5-new-vae.ckpt")
def main(base_model, vae, output_name):
    """Replace the --base-model VAE weights with those from --vae and save to --output-name."""
    # map_location="cpu" lets checkpoints saved on a GPU load on machines without CUDA.
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    model = torch.load(base_model, map_location="cpu")
    new_vae = torch.load(vae, map_location="cpu")
    # Accept either a checkpoint wrapping its weights in "state_dict" or a bare
    # state dict.
    vae_state_dict = new_vae.get("state_dict", new_vae)
    model["state_dict"] = overwrite_first_stage(model["state_dict"], vae_state_dict)
    print(f"Saving to {output_name}")
    torch.save(model, output_name)


if __name__ == "__main__":
    main()