adaface-animate / scripts /repl_vae.py
adaface-neurips's picture
re-init
02cc20b
raw
history blame
No virus
2.42 kB
import torch
from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file
import sys, argparse
def load_ckpt(ckpt_filepath):
    """Load a checkpoint file onto the CPU.

    Args:
        ckpt_filepath: Path to the checkpoint. Paths ending in ".safetensors"
            are read with safetensors; any other path is read with torch.load.

    Returns:
        A ``(ckpt, state_dict)`` tuple. ``ckpt`` is the whole object returned
        by torch.load, or None for a safetensors file. ``state_dict`` is
        ``ckpt["state_dict"]`` when the loaded object is a dict with that key,
        otherwise the loaded object itself.
    """
    print(f"Loading model from {ckpt_filepath}")
    if ckpt_filepath.endswith(".safetensors"):
        state_dict = safetensors_load_file(ckpt_filepath, device="cpu")
        return None, state_dict

    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    ckpt = torch.load(ckpt_filepath, map_location="cpu")
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        # The file holds a bare state dict (no wrapper), which is what
        # save_ckpt writes when it is given no full checkpoint object.
        state_dict = ckpt
    return ckpt, state_dict
def save_ckpt(ckpt, ckpt_state_dict, ckpt_filepath):
    """Write a checkpoint to ``ckpt_filepath`` and print the destination.

    Args:
        ckpt: Full checkpoint object, or None when only a state dict exists.
        ckpt_state_dict: The state dict to write when no full checkpoint is
            saved.
        ckpt_filepath: Output path. A ".safetensors" suffix selects the
            safetensors writer; any other path is written with torch.save.
    """
    use_safetensors = ckpt_filepath.endswith(".safetensors")
    if use_safetensors:
        # Only the state dict is passed to the safetensors writer.
        safetensors_save_file(ckpt_state_dict, ckpt_filepath)
    elif ckpt is None:
        torch.save(ckpt_state_dict, ckpt_filepath)
    else:
        # Save the whole checkpoint object rather than just the state dict.
        torch.save(ckpt, ckpt_filepath)
    print(f"Saved to {ckpt_filepath}")
def load_two_models(base_ckpt_filepath, vae_ckpt_filepath):
    """Load the base checkpoint and the checkpoint that supplies the VAE.

    Args:
        base_ckpt_filepath: Path to the base checkpoint.
        vae_ckpt_filepath: Path to the checkpoint providing the VAE weights.

    Returns:
        ``(base_ckpt, base_state_dict, vae_state_dict)``. The full base
        checkpoint object is returned alongside its state dict so it can be
        handed to save_ckpt later; only the state dict of the VAE checkpoint
        is kept.
    """
    base_loaded = load_ckpt(base_ckpt_filepath)
    vae_loaded = load_ckpt(vae_ckpt_filepath)
    # Index 0 is the full checkpoint object, index 1 is its state dict.
    return base_loaded[0], base_loaded[1], vae_loaded[1]
def replace_vae_weights(base_state_dict, vae_state_dict, prefix="first_stage_model."):
    """Overwrite the prefixed entries of ``base_state_dict`` with VAE weights.

    For every key in ``base_state_dict`` that starts with ``prefix``, the
    prefix is stripped and the remainder is looked up in ``vae_state_dict``.
    Entries that are missing from ``vae_state_dict`` or whose shapes differ
    are reported and left unchanged.

    Args:
        base_state_dict: State dict to update in place.
        vae_state_dict: State dict providing the replacement values.
        prefix: Key prefix marking the entries to replace.

    Returns:
        The number of entries that were replaced.
    """
    repl_count = 0
    # Only existing keys are reassigned, so the dict size never changes
    # while it is being iterated.
    for k in base_state_dict:
        if not k.startswith(prefix):
            continue
        # Strip only the leading prefix; str.replace would also remove any
        # later occurrence of the same substring inside the key.
        k2 = k[len(prefix):]
        if k2 not in vae_state_dict:
            print(f"!!!! '{k2}' not in VAE checkpoint")
            continue
        if base_state_dict[k].shape != vae_state_dict[k2].shape:
            print(f"!!!! '{k}' shape mismatch: {base_state_dict[k].shape} vs {vae_state_dict[k2].shape} !!!!")
            continue
        print(k)
        base_state_dict[k] = vae_state_dict[k2]
        repl_count += 1
    return repl_count


def main():
    """Parse the command line, swap in the VAE weights and save the result.

    Returns:
        0 when at least one parameter was replaced and the output checkpoint
        was saved, 1 when nothing was replaced (no file is written).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_ckpt", type=str, required=True, help="Path to the base checkpoint")
    parser.add_argument("--vae_ckpt", type=str, required=True, help="Path to the checkpoint providing vae")
    parser.add_argument("--out_ckpt", type=str, required=True, help="Path to the output checkpoint")
    args = parser.parse_args()

    base_ckpt, base_state_dict, vae_state_dict = load_two_models(args.base_ckpt, args.vae_ckpt)
    repl_count = replace_vae_weights(base_state_dict, vae_state_dict)

    if repl_count == 0:
        print("ERROR: No parameter replaced")
        return 1

    print(f"{repl_count} parameters replaced")
    save_ckpt(base_ckpt, base_state_dict, args.out_ckpt)
    return 0


if __name__ == "__main__":
    # Propagate failure through the exit status so calling scripts can detect it.
    sys.exit(main())