|
from argparse import ArgumentParser |
|
|
|
import torch |
|
|
|
|
|
def simplify_pth(pth_name, project_name, checkpoint_dir='./checkpoints', output_dir='.'):
    """Write a slimmed-down copy of a training checkpoint.

    Loads ``{checkpoint_dir}/{project_name}/{pth_name}`` and saves a new file
    ``{output_dir}/clean_{pth_name}`` that keeps only the ``epoch`` and
    ``state_dict`` entries of the original. The ``global_step``,
    ``checkpoint_callback_best``, ``optimizer_states`` and ``lr_schedulers``
    keys are still written, but with the value ``None``.

    Args:
        pth_name: File name of the checkpoint inside the project directory.
        project_name: Sub-directory of ``checkpoint_dir`` holding the checkpoint.
        checkpoint_dir: Root directory of the per-project checkpoint folders.
            Defaults to ``./checkpoints`` (relative to the working directory).
        output_dir: Directory the cleaned checkpoint is written to.
            Defaults to the current working directory.

    Returns:
        The path of the cleaned checkpoint that was written.

    Raises:
        KeyError: If the source checkpoint has no ``epoch`` or ``state_dict`` entry.
    """
    src_path = f'{checkpoint_dir}/{project_name}/{pth_name}'
    dst_path = f'{output_dir}/clean_{pth_name}'

    # Load every tensor onto the CPU so a checkpoint saved from GPU tensors can
    # be cleaned on a machine without that device; the cleaned file is then
    # device-independent as well.
    # NOTE(review): torch.load unpickles the file -- only run this on trusted checkpoints.
    checkpoint_dict = torch.load(src_path, map_location='cpu')

    torch.save({
        'epoch': checkpoint_dict['epoch'],
        'state_dict': checkpoint_dict['state_dict'],
        # The training-state keys are kept but emptied; presumably so code that
        # looks these keys up does not hit a KeyError -- confirm with the loader.
        'global_step': None,
        'checkpoint_callback_best': None,
        'optimizer_states': None,
        'lr_schedulers': None,
    }, dst_path)
    return dst_path
|
|
|
|
|
def main():
    """Command-line entry point: clean the checkpoint of one training step.

    Builds the file name ``model_ckpt_steps_<steps>.ckpt`` from ``--steps`` and
    passes it, together with ``--proj``, to ``simplify_pth``. Both flags are
    required; if either is missing, argparse prints a usage message and exits
    with status 2.
    """
    parser = ArgumentParser(
        description='Write a copy of a checkpoint keeping only epoch and state_dict.')
    # required=True: without it a missing flag parses as None and silently turns
    # into the literal text "None" inside the checkpoint path.
    parser.add_argument('--proj', type=str, required=True,
                        help='project name (sub-directory under ./checkpoints)')
    parser.add_argument('--steps', type=str, required=True,
                        help='step number used in model_ckpt_steps_<steps>.ckpt')
    args = parser.parse_args()

    model_name = f"model_ckpt_steps_{args.steps}.ckpt"
    simplify_pth(model_name, args.proj)
|
|
|
|
|
# Run the command-line tool only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|