File size: 1,439 Bytes
b072d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import click

def overwrite_first_stage(model_state_dict, vae_state_dict):
    """
    Overwrite the first-stage (VAE) encoder/decoder weights of a model.

    From the new repo:
         To keep compatibility with existing models,
         only the decoder part was finetuned;
         the checkpoints can be used as a drop-in replacement
         for the existing autoencoder.

    Every key of ``model_state_dict`` that contains ``"first_stage_model."``
    and mentions ``"decoder"`` or ``"encoder"`` is looked up in
    ``vae_state_dict`` under the part of the key that follows the prefix.
    Although the note above says only the decoder was finetuned, encoder
    entries are copied as well when the new VAE provides them. Other
    first-stage keys (anything naming neither "encoder" nor "decoder") are
    left untouched.

    Args:
        model_state_dict: mapping of parameter name -> value for the full
            model. It is modified in place.
        vae_state_dict: mapping of parameter name -> value for the new VAE,
            keyed without the ``"first_stage_model."`` prefix.

    Returns:
        The same ``model_state_dict`` object, after the in-place update.
        Keys with no counterpart in ``vae_state_dict`` keep their old value
        and a warning line is printed for each of them.
    """

    target = "first_stage_model."
    # Only values of existing keys are reassigned, so iterating the dict
    # while updating it is safe (its size never changes).
    for key in model_state_dict:
        if target not in key or not ("decoder" in key or "encoder" in key):
            continue

        # Keep *everything* after the first occurrence of the prefix.
        # (str.split(target)[1] would truncate the name if the prefix text
        # happened to appear a second time inside the key.)
        matching_name = key.partition(target)[2]

        # double check this weight exists in the new vae
        if matching_name in vae_state_dict:
            model_state_dict[key] = vae_state_dict[matching_name]
        else:
            print(f"{key} Does not exist in the new VAE weights!")

    return model_state_dict

@click.command()
@click.option("--base-model", type=str, default="sd-v1-5.ckpt")
@click.option("--vae", type=str, default="new_vae.ckpt")
@click.option("--output-name", type=str, default="sd-v1-5-new-vae.ckpt")
def main(base_model, vae, output_name):
    """Replace the first-stage (VAE) weights of a checkpoint and save it.

    Loads the checkpoint at BASE_MODEL and the VAE checkpoint at VAE, copies
    the matching VAE tensors into the base model's "state_dict", and writes
    the whole updated checkpoint to OUTPUT_NAME. Both input files must
    contain a top-level "state_dict" entry.
    """
    # SECURITY: torch.load unpickles its input and can execute arbitrary
    # code; only run this on checkpoint files from a source you trust.
    print(f"Loading {base_model} and {vae}")

    # map_location="cpu" lets checkpoints that were saved from a GPU be
    # merged on machines without CUDA, and avoids using GPU memory for what
    # is purely a copy operation.
    model = torch.load(base_model, map_location="cpu")
    new_vae = torch.load(vae, map_location="cpu")

    model["state_dict"] = overwrite_first_stage(model["state_dict"], new_vae["state_dict"])

    print(f"Saving to {output_name}")
    torch.save(model, output_name)


# Run the CLI only when executed as a script, so importing this module
# (e.g. to reuse overwrite_first_stage) does not trigger a checkpoint merge.
if __name__ == "__main__":
    main()