|
import torch |
|
import click |
|
|
|
def overwrite_first_stage(model_state_dict, vae_state_dict):
    """Overwrite first-stage encoder/decoder weights with new VAE weights.

    From the new repo:
        To keep compatibility with existing models, only the decoder part
        was finetuned; the checkpoints can be used as a drop-in replacement
        for the existing autoencoder.

    That suggests only the decoder weights need to change. This function
    nevertheless considers every key that contains "first_stage_model." and
    either "decoder" or "encoder"; other first-stage entries are left as-is.

    Args:
        model_state_dict: Mapping of parameter names to values for the full
            model. It is modified in place.
        vae_state_dict: Mapping of parameter names to replacement values.
            Names are looked up without the "first_stage_model." prefix.

    Returns:
        The same ``model_state_dict`` object, after the in-place update.
    """
    target = "first_stage_model."

    # Only values of existing keys are reassigned below, so the mapping never
    # changes size while it is being iterated.
    for key in model_state_dict:
        if target not in key:
            continue
        if "decoder" not in key and "encoder" not in key:
            continue

        # Keep everything after the FIRST occurrence of the prefix. A plain
        # split(target)[1] would truncate the name if the prefix text happened
        # to appear again later in the key.
        matching_name = key.partition(target)[2]

        if matching_name in vae_state_dict:
            model_state_dict[key] = vae_state_dict[matching_name]
        else:
            # Best effort: report the gap and keep the original weight.
            print(f"{key} Does not exist in the new VAE weights!")

    return model_state_dict
|
|
|
@click.command()
@click.option("--base-model", type=str, default="sd-v1-5.ckpt")
@click.option("--vae", type=str, default="new_vae.ckpt")
@click.option("--output-name", type=str, default="sd-v1-5-new-vae.ckpt")
def main(base_model, vae, output_name):
    """Swap new VAE weights into a model checkpoint.

    Loads the checkpoint at --base-model and the VAE checkpoint at --vae,
    overwrites the matching first-stage weights of the base model, and
    saves the merged checkpoint to --output-name. Both input files must
    contain a "state_dict" entry.
    """
    # NOTE: torch.load unpickles the file contents; only run this on
    # checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints that were saved from a GPU load on
    # CPU-only machines and avoids allocating GPU memory for a pure
    # key-copying job.
    model = torch.load(base_model, map_location="cpu")
    new_vae = torch.load(vae, map_location="cpu")

    model["state_dict"] = overwrite_first_stage(
        model["state_dict"], new_vae["state_dict"]
    )

    print(f"Saving to {output_name}")
    torch.save(model, output_name)
|
|
|
|
|
# Run the CLI only when executed as a script, so importing this module
# (e.g. to reuse overwrite_first_stage) has no side effects.
if __name__ == "__main__":
    main()
|
|