File size: 824 Bytes
ed1cdd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from argparse import ArgumentParser

import torch


def simplify_pth(pth_name, project_name, map_location='cpu'):
    """Write a slimmed-down copy of a training checkpoint.

    Loads ``./checkpoints/<project_name>/<pth_name>`` and saves a new file
    ``./clean_<pth_name>`` in the current working directory. The new file
    keeps only the ``epoch`` and ``state_dict`` entries of the source
    checkpoint. The ``global_step``, ``checkpoint_callback_best``,
    ``optimizer_states`` and ``lr_schedulers`` keys are written with the
    value ``None``. Every other key of the source checkpoint is dropped.

    Args:
        pth_name: File name of the checkpoint inside the project directory.
        project_name: Sub-directory of ``./checkpoints`` that holds the
            checkpoint.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            that a checkpoint saved on GPU can be cleaned on a machine without
            CUDA, and the cleaned file loads anywhere. Pass ``None`` to restore
            tensors to the devices they were saved from.

    Raises:
        KeyError: If the checkpoint has no ``'epoch'`` or ``'state_dict'``
            entry.
    """
    model_path = f'./checkpoints/{project_name}'
    # NOTE(review): torch.load unpickles the file (unless weights_only is in
    # effect) -- only run this on checkpoints from a trusted source.
    checkpoint_dict = torch.load(f'{model_path}/{pth_name}',
                                 map_location=map_location)
    clean_checkpoint = {
        'epoch': checkpoint_dict['epoch'],
        'state_dict': checkpoint_dict['state_dict'],
        # Training-resume state is deliberately discarded. The keys are kept
        # (as None), presumably so that loaders which index them don't raise
        # KeyError -- confirm against the consumer of the cleaned file.
        'global_step': None,
        'checkpoint_callback_best': None,
        'optimizer_states': None,
        'lr_schedulers': None,
    }
    torch.save(clean_checkpoint, f'./clean_{pth_name}')


def main():
    """Command-line entry point: clean the checkpoint for one training step.

    Reads ``--proj`` (the project directory name under ``./checkpoints``) and
    ``--steps`` (the step number embedded in the checkpoint file name). It then
    calls ``simplify_pth`` on ``model_ckpt_steps_<steps>.ckpt``. Both options
    are required; argparse exits with status 2 if either one is missing.
    """
    parser = ArgumentParser(
        description='Strip a checkpoint down to its epoch and model weights.')
    # Both options are required: without them the paths would silently be
    # built from the string 'None' and fail later with a confusing error.
    parser.add_argument('--proj', type=str, required=True,
                        help='project name; the checkpoint is read from '
                             './checkpoints/<proj>/')
    parser.add_argument('--steps', type=str, required=True,
                        help='training step of the checkpoint, i.e. '
                             'model_ckpt_steps_<steps>.ckpt')
    args = parser.parse_args()
    model_name = f"model_ckpt_steps_{args.steps}.ckpt"
    simplify_pth(model_name, args.proj)


if __name__ == '__main__':
    main()