ddd
add .gitattributes
871a48f
raw
history blame
484 Bytes
import sys
import torch
def strip_checkpoint(checkpoint, keys=('model', 'model_gen')):
    """Return a new checkpoint holding only one entry of ``checkpoint['state_dict']``.

    The first name in *keys* that is present in ``checkpoint['state_dict']`` is
    kept; every other entry of ``state_dict`` and every other top-level key of
    *checkpoint* is dropped. With the default *keys*, ``'model'`` wins over
    ``'model_gen'``.

    Args:
        checkpoint: Mapping with a ``'state_dict'`` mapping inside it.
        keys: Candidate entry names, in priority order.

    Returns:
        ``{'state_dict': {name: checkpoint['state_dict'][name]}}``.

    Raises:
        KeyError: if *checkpoint* has no ``'state_dict'``, or none of *keys*
            is present in it.
    """
    state_dict = checkpoint['state_dict']
    for name in keys:
        if name in state_dict:
            return {'state_dict': {name: state_dict[name]}}
    raise KeyError(
        f'none of {list(keys)} found in state_dict; available keys: {list(state_dict)}'
    )


def main(argv=None):
    """Strip the checkpoint file named on the command line down to one entry.

    Usage: ``python <script> CKPT_PATH [OUT_PATH]``. Without OUT_PATH the input
    file is overwritten in place.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) not in (1, 2):
        print(f'usage: {sys.argv[0]} CKPT_PATH [OUT_PATH]', file=sys.stderr)
        return 2
    ckpt_path = args[0]
    out_path = args[1] if len(args) == 2 else ckpt_path

    # NOTE(review): torch.load unpickles the file, so only run this on trusted
    # checkpoints. Newer torch releases default to weights_only=True, which may
    # reject checkpoints carrying non-tensor objects -- confirm for your torch version.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    print(checkpoint['state_dict'].keys())
    stripped = strip_checkpoint(checkpoint)
    # Legacy (non-zipfile) serialization; presumably kept so older torch
    # versions can still read the result.
    torch.save(stripped, out_path, _use_new_zipfile_serialization=False)
    return 0


if __name__ == '__main__':
    sys.exit(main())