Spaces:
Configuration error
Configuration error
from argparse import ArgumentParser | |
import torch | |
def simplify_pth(pth_name, project_name, *, map_location="cpu", output_path=None):
    """Write a stripped copy of a training checkpoint keeping epoch and weights.

    Loads ``./checkpoints/<project_name>/<pth_name>`` and saves a new dict that
    copies only ``epoch`` and ``state_dict``. The ``global_step``,
    ``checkpoint_callback_best``, ``optimizer_states`` and ``lr_schedulers``
    entries are written as ``None``. Every other key in the source checkpoint
    is dropped.

    Args:
        pth_name: File name of the checkpoint inside the project directory.
        project_name: Sub-directory of ``./checkpoints`` holding the checkpoint.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on a GPU can be cleaned on a CPU-only machine;
            pass ``None`` to keep tensors on the devices they were saved from.
        output_path: Destination file. Defaults to ``./clean_<pth_name>``,
            relative to the current working directory.

    Returns:
        The path the cleaned checkpoint was written to.

    Raises:
        KeyError: If the loaded checkpoint lacks ``epoch`` or ``state_dict``.
    """
    source_path = f'./checkpoints/{project_name}/{pth_name}'
    if output_path is None:
        output_path = f'./clean_{pth_name}'

    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    checkpoint_dict = torch.load(source_path, map_location=map_location)

    # Fail with a message naming the file instead of a bare KeyError('epoch').
    missing = [key for key in ('epoch', 'state_dict') if key not in checkpoint_dict]
    if missing:
        raise KeyError(f"{source_path} is missing required key(s): {missing}")

    torch.save({'epoch': checkpoint_dict['epoch'],
                'state_dict': checkpoint_dict['state_dict'],
                # Training-resume state keeps its keys but is emptied.
                'global_step': None,
                'checkpoint_callback_best': None,
                'optimizer_states': None,
                'lr_schedulers': None
                }, output_path)
    return output_path
def main(argv=None):
    """Command-line entry point: clean one step checkpoint of a project.

    Parses ``--proj`` and ``--steps``, builds the file name
    ``model_ckpt_steps_<steps>.ckpt`` and passes it, together with the
    project name, to ``simplify_pth``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behavior.
    """
    parser = ArgumentParser(
        description='Write a cleaned copy of '
                    './checkpoints/<proj>/model_ckpt_steps_<steps>.ckpt.')
    # Both flags are required: if either were omitted the path would contain
    # the literal text 'None' and only fail later, inside torch.load.
    parser.add_argument('--proj', type=str, required=True,
                        help='project directory name under ./checkpoints')
    parser.add_argument('--steps', type=str, required=True,
                        help='step number used in the checkpoint file name')
    args = parser.parse_args(argv)
    model_name = f"model_ckpt_steps_{args.steps}.ckpt"
    simplify_pth(model_name, args.proj)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()