File size: 471 Bytes
8121fee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import argparse
import os
import shutil
import tempfile

import torch


def main(checkpoint, keys=("optimizer", "lr_scheduler")):
    """Strip training-only entries from a saved checkpoint, rewriting it in place.

    Loads ``checkpoint`` onto the CPU and removes each top-level entry named in
    ``keys`` that is present. It then saves the result back to the same path.
    The new contents are written to a temporary file in the same directory,
    which then replaces the original via ``os.replace``. A failed or
    interrupted save therefore leaves the original checkpoint intact.

    Args:
        checkpoint: Path to a file produced by ``torch.save`` whose top-level
            object is a dict.
        keys: Top-level keys to drop. Defaults to ``("optimizer",
            "lr_scheduler")``.

    Returns:
        The list of keys that were actually removed, in ``keys`` order.

    Raises:
        TypeError: If the loaded checkpoint is not a dict (nothing is written).
    """
    # NOTE: torch.load unpickles the file -- only run this on trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location="cpu")
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"expected checkpoint to contain a dict, got {type(state_dict).__name__}"
        )

    removed = []
    for key in keys:
        if key in state_dict:
            del state_dict[key]
            removed.append(key)

    # Write beside the target so os.replace stays on one filesystem (atomic rename).
    directory = os.path.dirname(os.path.abspath(checkpoint))
    fd, tmp_path = tempfile.mkstemp(
        dir=directory, prefix=os.path.basename(checkpoint) + ".", suffix=".tmp"
    )
    os.close(fd)
    try:
        # mkstemp creates the file as 0o600; carry over the original's permissions.
        shutil.copymode(checkpoint, tmp_path)
        torch.save(state_dict, tmp_path)
        os.replace(tmp_path, checkpoint)
    finally:
        # The temp file only survives here if a step above failed; don't leave debris.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return removed


if __name__ == "__main__":
    # Command-line entry point: strip one checkpoint file given as a positional path.
    parser = argparse.ArgumentParser(
        description=(
            "Remove the 'optimizer' and 'lr_scheduler' entries from a PyTorch "
            "checkpoint. The file is overwritten in place."
        )
    )
    parser.add_argument(
        "checkpoint",
        type=str,
        help="path to the checkpoint file to strip (modified in place)",
    )
    args = parser.parse_args()
    main(args.checkpoint)