PandA / networks /genforce /convert_model.py
james-oldfield's picture
Upload 194 files
2a76164
raw
history blame contribute delete
No virus
3.95 kB
"""Script to convert officially released models to match this repository."""
import argparse
import os

from converters import convert_pggan_weight
from converters import convert_stylegan_weight
from converters import convert_stylegan2_weight
from converters import convert_stylegan2ada_tf_weight
from converters import convert_stylegan2ada_pth_weight
def parse_args(argv=None):
    """Parses command-line arguments for the model conversion script.

    Args:
        argv: Optional list of argument strings to parse. When `None`
            (default), `argparse` falls back to `sys.argv[1:]`, i.e. the
            arguments the script was launched with. Passing an explicit list
            lets the parser be driven from code or tests.

    Returns:
        An `argparse.Namespace` with the fields `model_type`,
        `source_model_path`, `target_model_path` (`None` when omitted),
        `test_num`, `save_test_image` and `verbose_log`.
    """
    parser = argparse.ArgumentParser(description='Convert pre-trained models.')
    parser.add_argument('model_type', type=str,
                        choices=['pggan', 'stylegan', 'stylegan2',
                                 'stylegan2ada_tf', 'stylegan2ada_pth'],
                        help='Type of the model to convert')
    parser.add_argument('--source_model_path', type=str, required=True,
                        help='Path to load the model for conversion.')
    parser.add_argument('--target_model_path', type=str, default=None,
                        help='Path to save the converted model. If not '
                             'specified, the model will be saved to the same '
                             'directory of the source model.')
    parser.add_argument('--test_num', type=int, default=10,
                        help='Number of test samples used to check the '
                             'precision of the converted model. (default: 10)')
    parser.add_argument('--save_test_image', action='store_true',
                        help='Whether to save the test image. (default: False)')
    parser.add_argument('--verbose_log', action='store_true',
                        help='Whether to print verbose log. (default: False)')
    return parser.parse_args(argv)
def main():
"""Main function."""
args = parse_args()
if args.target_model_path is None:
args.target_model_path = args.source_model_path.replace('.pkl', '.pth')
if args.model_type == 'pggan':
convert_pggan_weight(tf_weight_path=args.source_model_path,
pth_weight_path=args.target_model_path,
test_num=args.test_num,
save_test_image=args.save_test_image,
verbose=args.verbose_log)
elif args.model_type == 'stylegan':
convert_stylegan_weight(tf_weight_path=args.source_model_path,
pth_weight_path=args.target_model_path,
test_num=args.test_num,
save_test_image=args.save_test_image,
verbose=args.verbose_log)
elif args.model_type == 'stylegan2':
convert_stylegan2_weight(tf_weight_path=args.source_model_path,
pth_weight_path=args.target_model_path,
test_num=args.test_num,
save_test_image=args.save_test_image,
verbose=args.verbose_log)
elif args.model_type == 'stylegan2ada_tf':
convert_stylegan2ada_tf_weight(tf_weight_path=args.source_model_path,
pth_weight_path=args.target_model_path,
test_num=args.test_num,
save_test_image=args.save_test_image,
verbose=args.verbose_log)
elif args.model_type == 'stylegan2ada_pth':
convert_stylegan2ada_pth_weight(src_weight_path=args.source_model_path,
dst_weight_path=args.target_model_path,
test_num=args.test_num,
save_test_image=args.save_test_image,
verbose=args.verbose_log)
else:
raise NotImplementedError(f'Model type `{args.model_type}` is not '
f'supported!')
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()