# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom version for quantized training and evaluation functions.

The main difference between this and the third_party graph_rewriter_builder.py
is that this version uses experimental_create_training_graph, which allows the
customization of freeze_bn_delay.
"""

import re
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


def build(graph_rewriter_config,
          quant_overrides_config=None,
          is_training=True,
          is_export=False):
  """Returns a function that modifies the default graph based on options.

  Args:
    graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
    quant_overrides_config: quant_overrides_pb2.QuantOverrides proto.
    is_training: Whether in training or eval mode.
    is_export: Whether exporting the graph.

  Returns:
    A function that quantizes the default graph in place.
  """

  def graph_rewrite_fn():
    """Quantizes weights and activations of the default graph."""
    if (graph_rewriter_config.quantization.weight_bits != 8 or
        graph_rewriter_config.quantization.activation_bits != 8):
      raise ValueError('Only 8bit quantization is supported')

    graph = tf.get_default_graph()

    # Insert custom quant ops.
    if quant_overrides_config is not None:
      input_to_ops_map = input_to_ops.InputToOps(graph)
      for q in quant_overrides_config.quant_configs:
        # get_operation_by_name raises KeyError for an unknown name, so
        # surface the intended error instead of letting KeyError propagate.
        try:
          producer = graph.get_operation_by_name(q.op_name)
        except KeyError:
          raise ValueError('Op name does not exist in graph.')
        context = _get_context_from_op(producer)
        consumers = input_to_ops_map.ConsumerOperations(producer)
        if q.fixed_range:
          _insert_fixed_quant_op(
              context,
              q.quant_op_name,
              producer,
              consumers,
              init_min=q.min,
              init_max=q.max,
              quant_delay=q.delay if is_training else 0)
        else:
          raise ValueError('Learned ranges are not yet supported.')

    # Quantize the graph by inserting quantize ops for weights and activations.
    if is_training:
      contrib_quantize.experimental_create_training_graph(
          input_graph=graph,
          quant_delay=graph_rewriter_config.quantization.delay,
          freeze_bn_delay=graph_rewriter_config.quantization.delay)
    else:
      contrib_quantize.experimental_create_eval_graph(
          input_graph=graph,
          quant_delay=graph_rewriter_config.quantization.delay
          if not is_export else 0)

    contrib_layers.summarize_collection('quant_vars')

  return graph_rewrite_fn


def _get_context_from_op(op):
  """Gets the root context name from the op name."""
  context_re = re.search(r'^(.*)/([^/]+)', op.name)
  if context_re:
    return context_re.group(1)
  return ''
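

# A minimal sketch of how _get_context_from_op behaves. The op names are
# invented examples, and FakeOp is a hypothetical stand-in for a
# tf.Operation, used only for illustration:
#
#   class FakeOp(object):
#
#     def __init__(self, name):
#       self.name = name
#
#   _get_context_from_op(FakeOp('FeatureExtractor/Conv2D/BiasAdd'))
#   # -> 'FeatureExtractor/Conv2D'
#   _get_context_from_op(FakeOp('global_step'))
#   # -> ''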


def _insert_fixed_quant_op(context,
                           name,
                           producer,
                           consumers,
                           init_min=-6.0,
                           init_max=6.0,
                           quant_delay=None):
  """Adds a fake quant op with fixed ranges.

  Args:
    context: The parent scope of the op to be quantized.
    name: The name of the fake quant op.
    producer: The producer op to be quantized.
    consumers: The consumer ops of the producer op.
    init_min: The minimum range for the fake quant op.
    init_max: The maximum range for the fake quant op.
    quant_delay: Number of steps to wait before activating the fake quant op.

  Raises:
    ValueError: When the producer operation is not directly connected to the
      consumer operation.
  """
  name_prefix = name if not context else context + '/' + name
  inputs = producer.outputs[0]
  quant = quant_ops.FixedQuantize(
      inputs, init_min=init_min, init_max=init_max, scope=name_prefix)

  if quant_delay and quant_delay > 0:
    # Pass the unquantized tensor through until the global step reaches
    # quant_delay, then switch to the quantized tensor.
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = common.RerouteTensor(
        quant, inputs, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the size of the set.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
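

# Example usage, as a minimal sketch: build() is normally driven by the
# pipeline config, so constructing the GraphRewriter proto inline and the
# object_detection.protos import path below are illustrative assumptions.
#
#   from object_detection.protos import graph_rewriter_pb2
#
#   rewriter_config = graph_rewriter_pb2.GraphRewriter()
#   rewriter_config.quantization.delay = 500000
#   rewriter_config.quantization.weight_bits = 8
#   rewriter_config.quantization.activation_bits = 8
#
#   with tf.Graph().as_default():
#     # ... build the detection model here ...
#     graph_rewrite_fn = build(rewriter_config, is_training=True)
#     graph_rewrite_fn()  # Inserts fake-quant ops into the default graph.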