# Spaces: Build error
"""
@date: 2021/11/22
@description: Convert a training checkpoint into an inference checkpoint
"""
import argparse | |
import os | |
import torch | |
from config.defaults import merge_from_file | |
def parse_option(argv=None) -> argparse.Namespace:
    """Parse and echo the command-line options of the conversion script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of the script when run from the command line.

    Returns:
        The parsed namespace with two attributes: ``cfg`` (required path of
        the config file) and ``output_path`` (path of the output ckpt, or
        ``None`` when the option is not given).
    """
    parser = argparse.ArgumentParser(
        description='Conversion training ckpt into inference ckpt')
    parser.add_argument('--cfg',
                        type=str,
                        required=True,
                        metavar='FILE',
                        help='path of config file')
    parser.add_argument('--output_path',
                        type=str,
                        help='path of output ckpt')
    args = parser.parse_args(argv)

    # Echo every parsed option so the run is self-describing in the logs.
    print("arguments:")
    for name, value in vars(args).items():
        print(name, ":", value)
    print("-" * 50)
    return args
def convert_ckpt():
    """Save the ``'net'`` entry of the best training checkpoint on its own.

    Steps:
      1. Parse the command-line options with ``parse_option()`` and load the
         config with ``merge_from_file(args.cfg)``.
      2. Build the checkpoint directory
         ``checkpoints/<decoder_name>_<output_name>_Net/<TAG>`` from the
         first entry of ``config.MODEL.ARGS`` and ``config.TAG``.
      3. Pick a file in that directory whose name contains ``_best_``; if
         there is none, print a message and return without writing anything.
      4. Load it and write ``checkpoint['net']`` to ``--output_path``, or to
         ``<ck_dir>/best.pkl`` when that option is not given.

    Returns:
        None.
    """
    args = parse_option()
    config = merge_from_file(args.cfg)

    model_args = config.MODEL.ARGS[0]
    ck_dir = os.path.join(
        "checkpoints",
        f"{model_args['decoder_name']}_{model_args['output_name']}_Net",
        config.TAG)
    print(f"Processing {ck_dir}")

    # os.listdir() returns entries in arbitrary order; sort so the same file
    # is chosen on every run when several "_best_" checkpoints exist.
    model_paths = sorted(name for name in os.listdir(ck_dir) if '_best_' in name)
    if not model_paths:
        print("Not find best ckpt")
        return

    model_path = os.path.join(ck_dir, model_paths[0])
    print(f"Loading {model_path}")
    # Keep loading onto cuda:0 when it exists, but fall back to the CPU so the
    # conversion also works on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(model_path, map_location=device)
    net = checkpoint['net']

    if args.output_path is None:
        output_path = os.path.join(ck_dir, 'best.pkl')
    else:
        output_path = args.output_path
    print(f"Save on: {output_path}")

    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(net, output_path)
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    convert_ckpt()