"""Strip a training checkpoint down to its bare ``state_dict``.

Usage:
    python <this script> --ckpt_path full.ckpt --save_path pruned.ckpt
"""
import argparse

import torch


def prune_ckpt(ckpt_path: str, save_path: str) -> None:
    """Extract the ``"state_dict"`` entry of a checkpoint and save it alone.

    The checkpoint at ``ckpt_path`` is loaded onto the CPU, its
    ``"state_dict"`` entry is looked up, and only that object is written
    to ``save_path``; every other top-level entry is dropped.

    Args:
        ckpt_path: Path of the checkpoint file to read.
        save_path: Path the extracted state dict is written to.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    # Loading onto the CPU lets the script run on machines without the
    # device the checkpoint was saved from.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))

    try:
        state_dict = checkpoint["state_dict"]
    except KeyError as err:
        # Same exception type as before, but say which file is at fault and
        # what it actually contains.
        raise KeyError(
            f"checkpoint {ckpt_path!r} has no 'state_dict' entry; "
            f"top-level keys: {sorted(map(str, checkpoint))}"
        ) from err

    torch.save(state_dict, save_path)


def main() -> None:
    """Parse the command-line arguments and prune the given checkpoint."""
    parser = argparse.ArgumentParser(
        description="Keep only the 'state_dict' entry of a checkpoint."
    )
    # Both paths are mandatory: without them None would be handed to
    # torch.load / torch.save and fail with an unhelpful error.
    parser.add_argument('--ckpt_path', type=str, required=True,
                        help="checkpoint file to read")
    parser.add_argument('--save_path', type=str, required=True,
                        help="where to write the pruned state dict")
    args = parser.parse_args()

    prune_ckpt(args.ckpt_path, args.save_path)


if __name__ == '__main__':
    main()