Spaces:
Runtime error
Runtime error
File size: 636 Bytes
26f34b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from torch import load, save
if __name__ == '__main__':
print("优化器通常不会被用于推理阶段,如果只用于推理可以去除优化器以减小模型体积\n")
input_path = input("请输入模型的路径:")
output_path = f"{input_path.split('.')[0]}_inference.pth"
checkpoint_dict = load(input_path, map_location='cpu')
checkpoint_dict_new = {}
for k, v in checkpoint_dict.items():
if k == "optimizer":
print(f"remove optimizer")
continue
checkpoint_dict_new[k] = v
save(checkpoint_dict_new, output_path)
print("finish")
print(output_path)
|