# Adapted from https://gist.github.com/qqaatw/82b47c2b3da602fa1df604167bfcb9b0
"""Inspect, compare and rename the variables stored in a TensorFlow checkpoint.

The mode is selected by the command-line flags parsed in ``main``:

* compare -- ``--checkpoint_dir2`` given: for every variable of the first
  checkpoint that is absent from the second, print the closest names found
  in the second checkpoint.
* find -- ``--find_str`` given: print variables whose names contain (or, with
  a leading ``!``, do not contain) a substring.
* rename (default) -- rewrite variable names with a regex substitution and/or
  a prefix, then save every variable under its new name to a new checkpoint.
"""
import difflib
import getopt
import re
import sys

import tensorflow.compat.v1 as tf

DEFAULT_OUTPUT_PATH = 'renamed-model.ckpt'

usage_str = ('python tensorflow_rename_variables.py '
             '--checkpoint_dir=path/to/dir/ --replace_from=substr '
             '--replace_to=substr --add_prefix=abc --dry_run '
             '[--output_path=' + DEFAULT_OUTPUT_PATH + ']')
find_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ --find_str=[\'!\']substr')
comp_usage_str = ('python tensorflow_rename_variables.py '
                  '--checkpoint_dir=path/to/dir/ '
                  '--checkpoint_dir2=path/to/dir/')


def print_usage_str():
    """Print the usage lines for all three modes."""
    print('Please specify a checkpoint_dir. Usage:')
    print('%s\nor\n%s\nor\n%s' % (usage_str, find_usage_str, comp_usage_str))
    print('Note: checkpoint_dir should be a *DIR*, not a file')


def compare(checkpoint_dir, checkpoint_dir2):
    """Report variables of the first checkpoint that the second one lacks.

    For each variable name in ``checkpoint_dir`` that is not present in
    ``checkpoint_dir2``, print the most similar names from the second
    checkpoint as ranked by ``difflib.get_close_matches``.
    """
    names1 = [name for name, _ in tf.train.list_variables(checkpoint_dir)]
    names2 = [name for name, _ in tf.train.list_variables(checkpoint_dir2)]
    # A set gives O(1) membership tests; the list is still needed by difflib.
    known = set(names2)
    for name in names1:
        if name not in known:
            print('{} close matches: {}'.format(
                name, difflib.get_close_matches(name, names2)))


def find(checkpoint_dir, find_str):
    """Print checkpoint variables whose names match ``find_str``.

    A leading ``!`` negates the search: variables whose names do NOT contain
    the remaining substring are reported instead.
    """
    negate = find_str.startswith('!')
    if negate:
        find_str = find_str[1:]
    for var_name, _ in tf.train.list_variables(checkpoint_dir):
        present = find_str in var_name
        if negate and not present:
            print('%s missing from %s.' % (find_str, var_name))
        elif not negate and present:
            print('Found %s in %s.' % (find_str, var_name))


def _new_var_name(var_name, replace_from, replace_to, add_prefix):
    """Return ``var_name`` after applying the substitution and prefix rules."""
    new_name = var_name
    # replace_from is treated as a regular expression (re.sub), not a literal
    # substring, and is only applied when both halves of the pair are given.
    if replace_from is not None and replace_to is not None:
        new_name = re.sub(replace_from, replace_to, new_name)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run,
           output_path=DEFAULT_OUTPUT_PATH):
    """Rename every variable in a checkpoint and save the result.

    Args:
      checkpoint_dir: directory (or checkpoint prefix) to read variables from.
      replace_from: regex applied with ``re.sub`` to each name, or None.
      replace_to: replacement for ``replace_from`` matches, or None.
      add_prefix: string prepended to every (possibly substituted) name, or
        a falsy value for no prefix.
      dry_run: if true, only print the planned renames; nothing is loaded or
        written.
      output_path: checkpoint prefix passed to ``Saver.save``.
    """
    renames = []
    seen = set()
    for var_name, _ in tf.train.list_variables(checkpoint_dir):
        new_name = _new_var_name(var_name, replace_from, replace_to,
                                 add_prefix)
        if dry_run:
            print('%s would be renamed to %s.' % (var_name, new_name))
        elif var_name != new_name:
            print('Renaming %s to %s.' % (var_name, new_name))
        if new_name in seen:
            # Two source variables would be created under the same name;
            # the graph may uniquify one of them, so the saved name differs.
            print('Warning: more than one variable maps to %s.' % new_name)
        seen.add(new_name)
        renames.append((var_name, new_name))

    if dry_run:
        return

    # Entering the session also makes its graph the default, so the
    # variables below are graph variables that the Saver can collect.
    with tf.Session() as sess:
        for var_name, new_name in renames:
            value = tf.train.load_variable(checkpoint_dir, var_name)
            tf.Variable(value, name=new_name)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess, output_path)


def main(argv):
    """Parse command-line flags in ``argv`` and dispatch to a mode."""
    checkpoint_dir = None
    checkpoint_dir2 = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    find_str = None
    output_path = DEFAULT_OUTPUT_PATH
    try:
        # 'help' takes no argument; the previous 'help=' made a bare --help
        # fail with "option --help requires argument".
        opts, _ = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=',
                                            'replace_from=', 'replace_to=',
                                            'add_prefix=', 'dry_run',
                                            'find_str=', 'checkpoint_dir2=',
                                            'output_path='])
    except getopt.GetoptError as e:
        print(e)
        print_usage_str()
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--checkpoint_dir2':
            checkpoint_dir2 = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--find_str':
            find_str = arg
        elif opt == '--output_path':
            output_path = arg
    if not checkpoint_dir:
        print_usage_str()
        sys.exit(2)

    if checkpoint_dir2:
        compare(checkpoint_dir, checkpoint_dir2)
    elif find_str:
        find(checkpoint_dir, find_str)
    else:
        # `is None` so that an explicit empty --replace_to= (delete the match)
        # is still accepted.
        if (replace_from is None) != (replace_to is None):
            print('--replace_from and --replace_to must be given together.')
            print(usage_str)
            sys.exit(2)
        rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run,
               output_path)


if __name__ == '__main__':
    main(sys.argv[1:])