File size: 380 Bytes
4c4f051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import argparse
from pathlib import Path

import torch

# Top-level checkpoint keys removed by default. Optimizer state is not needed
# to run the model, so dropping it shrinks the saved file.
DROP_KEYS = ('optimizer_states',)


def prune_checkpoint(checkpoint, drop_keys=DROP_KEYS):
    """Return a shallow copy of *checkpoint* without the keys in *drop_keys*.

    The input mapping is left unmodified; the kept values are shared with
    it, not deep-copied.
    """
    return {k: v for k, v in checkpoint.items() if k not in drop_keys}


def pruned_path(path):
    """Return the default output path: ``pruned-<name>`` beside *path*.

    Only the file name is prefixed (not the whole path string), so an input
    like ``runs/exp1/model.ckpt`` maps to ``runs/exp1/pruned-model.ckpt``
    instead of the nonexistent ``pruned-runs/exp1/model.ckpt``.
    """
    path = Path(path)
    return path.with_name(f'pruned-{path.name}')


def main(argv=None):
    """Load a checkpoint, drop its optimizer state, and save the result.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The path the pruned checkpoint was written to.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-I', type=str, help='Input file to prune', required=True)
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='Output file (default: pruned-<input name> next to the input)')
    args = parser.parse_args(argv)

    # map_location='cpu' lets checkpoints saved from a GPU be pruned on a
    # CPU-only machine.
    # NOTE: torch.load is pickle-based -- only prune checkpoints you trust.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
    # checkpoints that store non-tensor Python objects; confirm for your files.
    checkpoint = torch.load(args.input, map_location='cpu')
    output = Path(args.output) if args.output else pruned_path(args.input)
    torch.save(prune_checkpoint(checkpoint), output)
    return output


if __name__ == '__main__':
    main()