"""Strip the optimizer state from a PyTorch checkpoint file.

Usage:
    python <this script> --input path/to/model.ckpt

Every top-level entry of the checkpoint except 'optimizer_states' is
written to a new file named 'pruned-<original file name>', placed next to
the input file.
"""
import argparse
import os

import torch


def strip_optimizer_states(checkpoint):
    """Return a shallow copy of *checkpoint* without 'optimizer_states'.

    Args:
        checkpoint: Mapping loaded from a checkpoint file.

    Returns:
        A new dict with the remaining entries in their original order. The
        values are the same objects as in *checkpoint* (not deep-copied),
        and *checkpoint* itself is left unmodified.
    """
    return {
        key: value
        for key, value in checkpoint.items()
        if key != 'optimizer_states'
    }


def pruned_path(file):
    """Return the output path for *file*: same directory, 'pruned-' prefix.

    The prefix is applied to the file name only. Prepending it to the whole
    path would turn 'runs/model.ckpt' into 'pruned-runs/model.ckpt', which
    points into a directory that usually does not exist.
    """
    directory, name = os.path.split(file)
    return os.path.join(directory, f'pruned-{name}')


def main(argv=None):
    """Parse the command line and write the pruned checkpoint.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-I', type=str,
                        help='Input file to prune', required=True)
    args = parser.parse_args(argv)

    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only run this script on checkpoints from a trusted source.
    checkpoint = torch.load(args.input)
    torch.save(strip_optimizer_states(checkpoint), pruned_path(args.input))


if __name__ == '__main__':
    main()