Spaces:
Build error
Build error
"""Script to convert officially released models to match this repository.""" | |
import argparse | |
from converters import convert_pggan_weight | |
from converters import convert_stylegan_weight | |
from converters import convert_stylegan2_weight | |
from converters import convert_stylegan2ada_tf_weight | |
from converters import convert_stylegan2ada_pth_weight | |
def parse_args(argv=None):
    """Parses command-line arguments for the model converter.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before. Passing a list allows the
            converter to be driven programmatically or from tests.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_type``,
        ``source_model_path``, ``target_model_path`` (``None`` when not
        given), ``test_num``, ``save_test_image`` and ``verbose_log``.
    """
    parser = argparse.ArgumentParser(description='Convert pre-trained models.')
    parser.add_argument('model_type', type=str,
                        choices=['pggan', 'stylegan', 'stylegan2',
                                 'stylegan2ada_tf', 'stylegan2ada_pth'],
                        help='Type of the model to convert')
    parser.add_argument('--source_model_path', type=str, required=True,
                        help='Path to load the model for conversion.')
    parser.add_argument('--target_model_path', type=str, default=None,
                        help='Path to save the converted model. If not '
                             'specified, the model will be saved to the same '
                             'directory of the source model.')
    parser.add_argument('--test_num', type=int, default=10,
                        help='Number of test samples used to check the '
                             'precision of the converted model. (default: 10)')
    parser.add_argument('--save_test_image', action='store_true',
                        help='Whether to save the test image. (default: False)')
    parser.add_argument('--verbose_log', action='store_true',
                        help='Whether to print verbose log. (default: False)')
    # argparse treats argv=None as "use sys.argv[1:]".
    return parser.parse_args(argv)
def main():
    """Main function: converts the model selected on the command line.

    Parses the CLI arguments and dispatches to the converter matching
    ``model_type``. When ``--target_model_path`` is omitted, the converted
    weights are written next to the source file, with the source's final
    extension replaced by ``.pth``.

    Raises:
        ValueError: If ``--target_model_path`` is omitted and the derived
            default would equal the source path (e.g. the source already ends
            in ``.pth``), which would overwrite the source model.
        NotImplementedError: If ``model_type`` has no registered converter.
    """
    import os  # Only needed to derive the default target path.

    args = parse_args()

    if args.target_model_path is None:
        # Replace only the final extension. This leaves '.pkl' appearing
        # elsewhere in the path (e.g. in a directory name) untouched, and
        # gives extension-less sources a distinct '.pth' target instead of
        # the source path itself.
        root, _ = os.path.splitext(args.source_model_path)
        args.target_model_path = root + '.pth'
        if args.target_model_path == args.source_model_path:
            raise ValueError(
                f'Default target path `{args.target_model_path}` equals the '
                f'source path; pass `--target_model_path` explicitly to '
                f'avoid overwriting the source model.')

    # Model type -> (converter, (source-path kwarg, target-path kwarg)).
    # The TensorFlow-based converters and the PyTorch StyleGAN2-ADA converter
    # name their path parameters differently.
    tf_keys = ('tf_weight_path', 'pth_weight_path')
    pth_keys = ('src_weight_path', 'dst_weight_path')
    converters = {
        'pggan': (convert_pggan_weight, tf_keys),
        'stylegan': (convert_stylegan_weight, tf_keys),
        'stylegan2': (convert_stylegan2_weight, tf_keys),
        'stylegan2ada_tf': (convert_stylegan2ada_tf_weight, tf_keys),
        'stylegan2ada_pth': (convert_stylegan2ada_pth_weight, pth_keys),
    }

    if args.model_type not in converters:
        raise NotImplementedError(f'Model type `{args.model_type}` is not '
                                  f'supported!')

    convert_fn, (src_key, dst_key) = converters[args.model_type]
    convert_fn(**{src_key: args.source_model_path,
                  dst_key: args.target_model_path},
               test_num=args.test_num,
               save_test_image=args.save_test_image,
               verbose=args.verbose_log)
if __name__ == "__main__":
    # Run the converter only when executed as a script, not when imported.
    main()