Spaces:
Runtime error
Runtime error
# ------------------------------------------------------------------------------ | |
# Copyright (c) Microsoft | |
# Licensed under the MIT License. | |
# Written by Zigang Geng (zigang@mail.ustc.edu.cn) | |
# ------------------------------------------------------------------------------ | |
from __future__ import annotations | |
import sys | |
import torch | |
from argparse import ArgumentParser | |
from omegaconf import OmegaConf | |
sys.path.append("./stable_diffusion") | |
from stable_diffusion.ldm.util import instantiate_from_config | |
def _remap_ema_keys(ema_state, param_names):
    """Rename EMA weights back to the model's dotted parameter names.

    Args:
        ema_state: Mapping read from the checkpoint's ``'model_ema'`` entry.
            Its keys are expected to be parameter names with every ``.``
            removed, because that is the form matched against here.
        param_names: Dotted parameter names of the target model, in
            ``named_parameters()`` order.

    Returns:
        A new dict keyed by dotted parameter name. EMA keys that match no
        model parameter are reported on stdout and skipped.
    """
    # Build the stripped-name index once so each lookup is O(1) instead of a
    # list.index scan per key. setdefault keeps the FIRST parameter whose
    # stripped form collides, matching list.index's first-match behaviour.
    stripped_to_name = {}
    for name in param_names:
        stripped_to_name.setdefault(name.replace(".", ""), name)

    remapped = {}
    for key, value in ema_state.items():
        name = stripped_to_name.get(key)
        if name is None:
            print(f"{key} is not in the list.")
            continue
        remapped[name] = value
    return remapped


def _fill_missing_params(state, base_state, param_names):
    """Copy weights from ``base_state`` for model parameters absent from ``state``.

    Only keys that are parameters of the target model and not already in
    ``state`` are copied. This fills ANY missing model parameter, not only
    VAE weights. ``state`` is updated in place and also returned.
    """
    wanted = set(param_names)  # set membership instead of a list scan per key
    for key, value in base_state.items():
        if key in wanted and key not in state:
            state[key] = value
    return state


def main():
    """Merge EMA weights (plus fallback base weights) into one ``state_dict`` checkpoint."""
    from pathlib import Path

    parser = ArgumentParser(
        description="Convert an EMA training checkpoint into a loadable state_dict checkpoint."
    )
    parser.add_argument("--config", default="configs/instruct_diffusion.yaml", type=str,
                        help="model config used to instantiate the model (for parameter names)")
    parser.add_argument("--ema-ckpt", default="logs/instruct_diffusion/checkpoints/ckpt_epoch_200/state.pth", type=str,
                        help="training checkpoint containing a 'model_ema' entry")
    parser.add_argument("--vae-ckpt", default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt", type=str,
                        help="base checkpoint whose 'state_dict' fills parameters missing from the EMA")
    parser.add_argument("--out-ckpt", default="checkpoints/v1-5-pruned-emaonly-adaption-task.ckpt", type=str,
                        help="path of the converted checkpoint to write")
    args = parser.parse_args()

    # The model is instantiated only to learn its parameter names.
    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)
    param_names = [name for name, _ in model.named_parameters()]
    del model  # release the randomly initialised weights before loading checkpoints

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # On newer torch versions weights_only may default to True; confirm these
    # checkpoints load under that setting.
    ema_ckpt = torch.load(args.ema_ckpt, map_location="cpu")
    new_ema_ckpt = _remap_ema_keys(ema_ckpt["model_ema"], param_names)
    del ema_ckpt  # drop everything else stored in the training checkpoint

    vae_ckpt = torch.load(args.vae_ckpt, map_location="cpu")
    _fill_missing_params(new_ema_ckpt, vae_ckpt["state_dict"], param_names)
    del vae_ckpt

    out_path = Path(args.out_ckpt)
    # The default output directory may not exist yet; torch.save would fail.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"state_dict": new_ema_ckpt}, out_path)
    print(f"Converted successfully, the new checkpoint has been saved to {args.out_ckpt}")


if __name__ == "__main__":
    main()