File size: 381 Bytes
a662214
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Prune a training checkpoint down to inference weights: drop the
# (typically huge) "optimizer_states" entry and normalize the layout so the
# output always has a top-level "state_dict" key.
#
# NOTE(security): torch.load unpickles arbitrary objects — only run this on
# checkpoints from a trusted source.
sd = torch.load(model_path, map_location="cpu")

# Copy everything except the optimizer state. The membership test for
# "state_dict" is loop-invariant, so decide the output layout once instead
# of re-checking it per key.
kept = {k: v for k, v in sd.items() if k != "optimizer_states"}

if "state_dict" in sd:
    # Checkpoint already uses the {"state_dict": {...}, ...} layout:
    # keep the same top-level structure, minus the optimizer state.
    pruned_sd = kept
else:
    # Flat checkpoint (weights at the top level): wrap the surviving
    # keys under "state_dict" so downstream loaders find them.
    pruned_sd = {"state_dict": kept}

torch.save(pruned_sd, "model-pruned.ckpt")