# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for manipulating collections of variables during training.
"""
import logging
import re

import tensorflow as tf

from tensorflow.python.ops import variables as tf_variables

slim = tf.contrib.slim


# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Filters out the variables matching the filter_regex.

  Filters out the variables whose names match any of the regular expressions
  in filter_regex_list and returns the remaining variables. Optionally, if
  invert=True, the complement set is returned.

  Args:
    variables: a list of tensorflow variables.
    filter_regex_list: a list of string regular expressions.
    invert: (boolean). If True, returns the complement of the filter set; that
      is, all variables matching filter_regex are kept and all others
      discarded.

  Returns:
    a list of filtered variables.
  """
  kept_vars = []
  variables_to_ignore_patterns = list(filter(None, filter_regex_list))
  for var in variables:
    add = True
    for pattern in variables_to_ignore_patterns:
      if re.match(pattern, var.op.name):
        add = False
        break
    if add != invert:
      kept_vars.append(var)
  return kept_vars


def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Multiply gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.
    multiplier: A (float) multiplier to apply to each gradient matching the
      regular expression.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples).
  """
  variables = [pair[1] for pair in grads_and_vars]
  matching_vars = filter_variables(variables, regex_list, invert=True)
  for var in matching_vars:
    logging.info('Applying multiplier %f to variable [%s]',
                 multiplier, var.op.name)
  grad_multipliers = {var: float(multiplier) for var in matching_vars}
  return slim.learning.multiply_gradients(grads_and_vars, grad_multipliers)
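
# Example usage of the gradient helpers above (a minimal sketch, not part of
# the original module; the `total_loss` tensor, the optimizer settings, and
# the '.*/biases' pattern are hypothetical):
#
#   optimizer = tf.train.MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
#   grads_and_vars = optimizer.compute_gradients(total_loss)
#   # Double the gradient of every variable whose name matches '.*/biases'.
#   grads_and_vars = multiply_gradients_matching_regex(
#       grads_and_vars, ['.*/biases'], multiplier=2.0)
#   train_op = optimizer.apply_gradients(grads_and_vars)
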
""" variables = [pair[1] for pair in grads_and_vars] matching_vars = filter_variables(variables, regex_list, invert=True) kept_grads_and_vars = [pair for pair in grads_and_vars if pair[1] not in matching_vars] for var in matching_vars: logging.info('Freezing variable [%s]', var.op.name) return kept_grads_and_vars def get_variables_available_in_checkpoint(variables, checkpoint_path, include_global_step=True): """Returns the subset of variables available in the checkpoint. Inspects given checkpoint and returns the subset of variables that are available in it. TODO(rathodv): force input and output to be a dictionary. Args: variables: a list or dictionary of variables to find in checkpoint. checkpoint_path: path to the checkpoint to restore variables from. include_global_step: whether to include `global_step` variable, if it exists. Default True. Returns: A list or dictionary of variables. Raises: ValueError: if `variables` is not a list or dict. """ if isinstance(variables, list): variable_names_map = {} for variable in variables: if isinstance(variable, tf_variables.PartitionedVariable): name = variable.name else: name = variable.op.name variable_names_map[name] = variable elif isinstance(variables, dict): variable_names_map = variables else: raise ValueError('`variables` is expected to be a list or dict.') ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path) ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map() if not include_global_step: ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None) vars_in_ckpt = {} for variable_name, variable in sorted(variable_names_map.items()): if variable_name in ckpt_vars_to_shape_map: if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list(): vars_in_ckpt[variable_name] = variable else: logging.warning('Variable [%s] is available in checkpoint, but has an ' 'incompatible shape with model variable. Checkpoint ' 'shape: [%s], model variable shape: [%s]. This ' 'variable will not be initialized from the checkpoint.', variable_name, ckpt_vars_to_shape_map[variable_name], variable.shape.as_list()) else: logging.warning('Variable [%s] is not available in checkpoint', variable_name) if isinstance(variables, list): return vars_in_ckpt.values() return vars_in_ckpt