File size: 1,439 Bytes
b072d5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import torch
import click
def overwrite_first_stage(model_state_dict, vae_state_dict):
    """
    Overwrite the first-stage (autoencoder) weights of a model state dict.

    From the new VAE repo:
        To keep compatibility with existing models,
        only the decoder part was finetuned;
        the checkpoints can be used as a drop-in replacement
        for the existing autoencoder.

    Every key that contains ``first_stage_model.`` and also contains
    ``decoder`` or ``encoder`` is replaced by the VAE entry whose name is
    the part of the key after the first ``first_stage_model.``. Both encoder
    and decoder entries are swapped. First-stage keys whose name contains
    neither substring are left untouched. Keys with no counterpart in
    ``vae_state_dict`` are kept, and a message is printed.

    Args:
        model_state_dict: Mapping of full-model parameter names to values.
            Modified in place.
        vae_state_dict: Mapping of standalone-VAE parameter names (without
            the ``first_stage_model.`` prefix) to values.

    Returns:
        The same ``model_state_dict`` object, after the in-place update.
    """
    target = "first_stage_model."
    # Only existing keys are reassigned, so the dict size never changes and
    # iterating it directly is safe.
    for key in model_state_dict:
        if target in key and ("decoder" in key or "encoder" in key):
            # Split only at the first occurrence so a name that happens to
            # contain the prefix again is not truncated.
            matching_name = key.split(target, 1)[1]
            # Double-check this weight exists in the new VAE.
            if matching_name in vae_state_dict:
                model_state_dict[key] = vae_state_dict[matching_name]
            else:
                print(f"{key} Does not exist in the new VAE weights!")
    return model_state_dict
@click.command()
@click.option("--base-model", type=str, default="sd-v1-5.ckpt")
@click.option("--vae", type=str, default="new_vae.ckpt")
@click.option("--output-name", type=str, default="sd-v1-5-new-vae.ckpt")
def main(base_model, vae, output_name):
    """Replace the first-stage (VAE) weights of a checkpoint and save the result."""
    # map_location="cpu" lets checkpoints that were saved from GPU tensors
    # load on machines without a GPU; no computation happens here anyway.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    model = torch.load(base_model, map_location="cpu")
    new_vae = torch.load(vae, map_location="cpu")
    # Fall back to treating the loaded object as the state dict itself when
    # it has no "state_dict" wrapper entry.
    vae_state_dict = new_vae.get("state_dict", new_vae)
    model["state_dict"] = overwrite_first_stage(model["state_dict"], vae_state_dict)
    print(f"Saving to {output_name}")
    torch.save(model, output_name)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|