import torch | |
# Split a VQGAN checkpoint into two files:
#   * ..._vae_only.ckpt -- the whole checkpoint, with every entry whose key
#                          contains 'decoder.fusion_layer' removed from its
#                          'state_dict'
#   * ..._cfw_only.ckpt -- a plain dict holding just those removed entries
# Each extracted key is printed as it is found.
vae_path = 'models/vqgan_cfw_00011.ckpt'

# NOTE(review): torch.load unpickles the file, so only run this on
# checkpoints from a trusted source.
with open(vae_path, 'rb') as f:
    # map_location='cpu' loads every tensor onto the CPU, whatever device it
    # was saved from.
    vae_ckpt = torch.load(f, map_location='cpu')

state_dict = vae_ckpt['state_dict']

# Collect the matching keys first. The state dict is only mutated after
# this, so it is never changed while being iterated.
prune_keys = [name for name in state_dict if 'decoder.fusion_layer' in name]
for name in prune_keys:
    print(name)

# pop() both reads each entry and removes it from the checkpoint's state dict.
vae_cfw = {name: state_dict.pop(name) for name in prune_keys}

torch.save(vae_ckpt, 'models/vqgan_cfw_00011_vae_only.ckpt')
torch.save(vae_cfw, 'models/vqgan_cfw_00011_cfw_only.ckpt')