File size: 636 Bytes
26f34b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch import load, save

if __name__ == '__main__':
    print("优化器通常不会被用于推理阶段,如果只用于推理可以去除优化器以减小模型体积\n")
    input_path = input("请输入模型的路径:")
    output_path = f"{input_path.split('.')[0]}_inference.pth"
    checkpoint_dict = load(input_path, map_location='cpu')
    checkpoint_dict_new = {}
    for k, v in checkpoint_dict.items():
        if k == "optimizer":
            print(f"remove optimizer")
            continue
        checkpoint_dict_new[k] = v
    save(checkpoint_dict_new, output_path)
    print("finish")
    print(output_path)