nousr commited on
Commit
b072d5b
1 Parent(s): da548ae

Upload swap_vae.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. swap_vae.py +45 -0
swap_vae.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import click
3
+
4
+ def overwrite_first_stage(model_state_dict, vae_state_dict):
5
+ """
6
+ Overwrite the First Stage Decoders.
7
+
8
+ From the new repo:
9
+ To keep compatibility with existing models,
10
+ only the decoder part was finetuned;
11
+ the checkpoints can be used as a drop-in replacement
12
+ for the existing autoencoder.
13
+
14
+ Sounds like we only need to change the decoder weights.
15
+ """
16
+
17
+ target = "first_stage_model."
18
+ for key in model_state_dict.keys():
19
+ if target in key and ("decoder" in key or "encoder" in key):
20
+ matching_name = key.split(target)[1]
21
+
22
+ # double check this weight exists in the new vae
23
+ if matching_name in vae_state_dict:
24
+ model_state_dict[key] = vae_state_dict[matching_name]
25
+ else:
26
+ print(f"{key} Does not exist in the new VAE weights!")
27
+
28
+ return model_state_dict
29
+
30
+ @click.command()
31
+ @click.option("--base-model", type=str, default="sd-v1-5.ckpt")
32
+ @click.option("--vae", type=str, default="new_vae.ckpt")
33
+ @click.option("--output-name", type=str, default="sd-v1-5-new-vae.ckpt")
34
+ def main(base_model, vae, output_name):
35
+ print("hello")
36
+ model = torch.load(base_model)
37
+ new_vae = torch.load(vae)
38
+
39
+ model["state_dict"] = overwrite_first_stage(model["state_dict"], new_vae["state_dict"])
40
+
41
+ print(f"Saving to {output_name}")
42
+ torch.save(model, output_name)
43
+
44
+
45
+ main()