# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom version for quantized training and evaluation functions.
The main difference between this and the third_party graph_rewriter_builder.py
is that this version uses experimental_create_training_graph which allows the
customization of freeze_bn_delay.
"""
import re
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
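
# Background note (hedged, based on the TF 1.x contrib API rather than this
# file): the non-experimental contrib_quantize.create_training_graph exposes
# only quant_delay, whereas experimental_create_training_graph also accepts
# freeze_bn_delay, which is why this custom builder uses the experimental
# variant.
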
def build(graph_rewriter_config,
quant_overrides_config=None,
is_training=True,
is_export=False):
"""Returns a function that modifies default graph based on options.
Args:
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
quant_overrides_config: quant_overrides_pb2.QuantOverrides proto.
is_training: whether in training or eval mode.
is_export: whether exporting the graph.
"""
def graph_rewrite_fn():
"""Function to quantize weights and activation of the default graph."""
if (graph_rewriter_config.quantization.weight_bits != 8 or
graph_rewriter_config.quantization.activation_bits != 8):
raise ValueError('Only 8bit quantization is supported')
graph = tf.get_default_graph()
# Insert custom quant ops.
if quant_overrides_config is not None:
input_to_ops_map = input_to_ops.InputToOps(graph)
for q in quant_overrides_config.quant_configs:
        try:
          producer = graph.get_operation_by_name(q.op_name)
        except KeyError:
          raise ValueError('Op name does not exist in graph: %s' % q.op_name)
context = _get_context_from_op(producer)
consumers = input_to_ops_map.ConsumerOperations(producer)
if q.fixed_range:
_insert_fixed_quant_op(
context,
q.quant_op_name,
producer,
consumers,
init_min=q.min,
init_max=q.max,
quant_delay=q.delay if is_training else 0)
else:
raise ValueError('Learned ranges are not yet supported.')
# Quantize the graph by inserting quantize ops for weights and activations
if is_training:
contrib_quantize.experimental_create_training_graph(
input_graph=graph,
quant_delay=graph_rewriter_config.quantization.delay,
freeze_bn_delay=graph_rewriter_config.quantization.delay)
else:
contrib_quantize.experimental_create_eval_graph(
input_graph=graph,
quant_delay=graph_rewriter_config.quantization.delay
if not is_export else 0)
contrib_layers.summarize_collection('quant_vars')
return graph_rewrite_fn
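

# Example usage (a hedged sketch; it assumes the object_detection protos are
# importable, as in the TF Object Detection API, and that the model has
# already been built in the default graph):
#
#   from object_detection.protos import graph_rewriter_pb2
#
#   rewriter_config = graph_rewriter_pb2.GraphRewriter()
#   rewriter_config.quantization.delay = 500000
#   rewriter_config.quantization.weight_bits = 8
#   rewriter_config.quantization.activation_bits = 8
#
#   graph_rewrite_fn = build(rewriter_config, is_training=True)
#   graph_rewrite_fn()  # Rewrites the default graph in place.
#
# Optionally, fixed-range overrides (field names as referenced in this
# module; the op name below is hypothetical):
#
#   overrides = quant_overrides_pb2.QuantOverrides()
#   qc = overrides.quant_configs.add()
#   qc.op_name = 'FeatureExtractor/some_op'
#   qc.quant_op_name = 'act_quant'
#   qc.fixed_range = True
#   qc.min = -6.0
#   qc.max = 6.0
#   qc.delay = 500000
#   graph_rewrite_fn = build(rewriter_config, overrides, is_training=True)
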
def _get_context_from_op(op):
"""Gets the root context name from the op name."""
context_re = re.search(r'^(.*)/([^/]+)', op.name)
if context_re:
return context_re.group(1)
return ''
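

# Example for _get_context_from_op (illustrative op names, not from the
# original file): for an op named 'FeatureExtractor/Conv2d_0/Conv2D' the
# regex above returns 'FeatureExtractor/Conv2d_0'; an op name without '/'
# does not match and the function returns ''.
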
def _insert_fixed_quant_op(context,
name,
producer,
consumers,
init_min=-6.0,
init_max=6.0,
quant_delay=None):
"""Adds a fake quant op with fixed ranges.
Args:
context: The parent scope of the op to be quantized.
name: The name of the fake quant op.
producer: The producer op to be quantized.
consumers: The consumer ops to the producer op.
init_min: The minimum range for the fake quant op.
init_max: The maximum range for the fake quant op.
quant_delay: Number of steps to wait before activating the fake quant op.
Raises:
ValueError: When producer operation is not directly connected to the
consumer operation.
"""
name_prefix = name if not context else context + '/' + name
inputs = producer.outputs[0]
quant = quant_ops.FixedQuantize(
inputs, init_min=init_min, init_max=init_max, scope=name_prefix)
if quant_delay and quant_delay > 0:
activate_quant = math_ops.greater_equal(
common.CreateOrGetQuantizationStep(),
quant_delay,
name=name_prefix + '/activate_quant')
quant = control_flow_ops.cond(
activate_quant,
lambda: quant,
lambda: inputs,
name=name_prefix + '/delayed_quant')
if consumers:
tensors_modified_count = common.RerouteTensor(
quant, inputs, can_modify=consumers)
    # A consumer op can use the producer's output tensor through more than one
    # of its inputs, so tensors_modified_count can exceed len(consumers). If it
    # is smaller than len(consumers), at least one consumer was not rewired to
    # the quantized tensor.
if tensors_modified_count < len(consumers):
raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
[consumer.name for consumer in consumers]))
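

# A hedged sketch of the rewrite performed by _insert_fixed_quant_op, with
# illustrative names. Given a producer 'scope/op' feeding 'scope/consumer':
#
#   scope/op -> scope/consumer
#
# becomes
#
#   scope/op -> <name_prefix> (FixedQuantize) -> scope/consumer
#
# and, when quant_delay > 0, the FixedQuantize output is wrapped in a
# '<name_prefix>/delayed_quant' tf.cond that passes the unquantized tensor
# through until the global quantization step reaches quant_delay.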