# InstructDiffusion: scripts/convert_ckpt.py
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Zigang Geng (zigang@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------
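# Merge the EMA weights from an InstructDiffusion training run with the
# remaining weights (e.g. the frozen autoencoder) from the original Stable
# Diffusion v1.5 checkpoint, producing a single standalone checkpoint.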
from __future__ import annotations

import sys
from argparse import ArgumentParser

import torch
from omegaconf import OmegaConf

# Make the bundled Stable Diffusion code importable (the script is expected
# to be run from the repository root).
sys.path.append("./stable_diffusion")
from stable_diffusion.ldm.util import instantiate_from_config

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", default="configs/instruct_diffusion.yaml", type=str)
    parser.add_argument("--ema-ckpt", default="logs/instruct_diffusion/checkpoints/ckpt_epoch_200/state.pth", type=str)
    parser.add_argument("--vae-ckpt", default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt", type=str)
    parser.add_argument("--out-ckpt", default="checkpoints/v1-5-pruned-emaonly-adaption-task.ckpt", type=str)
    args = parser.parse_args()
    # Instantiate the model from the config so that the canonical parameter
    # names are available for remapping the EMA weights.
    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)

    ema_ckpt = torch.load(args.ema_ckpt, map_location="cpu")
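    # Note: in the bundled LDM code, LitEma stores its shadow copies under the
    # parameter names with the dots stripped (e.g. "model.x.weight" becomes
    # "modelxweight"), so the EMA keys must be mapped back to the real names.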
    all_keys = [key for key, _ in model.named_parameters()]
    all_keys_rmv = [key.replace('.', '') for key in all_keys]
    new_ema_ckpt = {}
    for k, v in ema_ckpt['model_ema'].items():
        try:
            k_index = all_keys_rmv.index(k)
            new_ema_ckpt[all_keys[k_index]] = v
        except ValueError:
            # Non-parameter entries (e.g. EMA bookkeeping buffers) end up here.
            print(k + ' is not in the list.')
    # Any parameter not covered by the EMA state (notably the frozen VAE
    # weights) is copied over from the original Stable Diffusion checkpoint.
    vae_ckpt = torch.load(args.vae_ckpt, map_location="cpu")
    for k, v in vae_ckpt['state_dict'].items():
        if k not in new_ema_ckpt and k in all_keys:
            new_ema_ckpt[k] = v
    checkpoint = {'state_dict': new_ema_ckpt}
    with open(args.out_ckpt, 'wb') as f:
        torch.save(checkpoint, f)
        f.flush()
    print(f'Converted successfully, the new checkpoint has been saved to {args.out_ckpt}')
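
# Example invocation (every flag falls back to the defaults above; the paths
# shown are those defaults, not a prescribed layout):
#   python scripts/convert_ckpt.py \
#       --ema-ckpt logs/instruct_diffusion/checkpoints/ckpt_epoch_200/state.pth \
#       --out-ckpt checkpoints/v1-5-pruned-emaonly-adaption-task.ckpt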