Spaces:
Running
Running
File size: 5,658 Bytes
0b8359d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom version for quantized training and evaluation functions.
The main difference between this and the third_party graph_rewriter_builder.py
is that this version uses experimental_create_training_graph which allows the
customization of freeze_bn_delay.
"""
import re
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
def build(graph_rewriter_config,
          quant_overrides_config=None,
          is_training=True,
          is_export=False):
  """Returns a function that modifies default graph based on options.

  Args:
    graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
    quant_overrides_config: quant_overrides_pb2.QuantOverrides proto.
    is_training: whether in training or eval mode.
    is_export: whether exporting the graph.

  Returns:
    A zero-argument function that, when called, rewrites the current default
    graph in place by inserting quantization ops.
  """

  def graph_rewrite_fn():
    """Function to quantize weights and activation of the default graph.

    Raises:
      ValueError: If the configured weight or activation bit width is not 8,
        if a quant override names an op that is not in the default graph, if
        an override does not request a fixed range, or if a fixed quant op
        could not be wired into all of its consumers.
    """
    quantization = graph_rewriter_config.quantization
    if quantization.weight_bits != 8 or quantization.activation_bits != 8:
      raise ValueError('Only 8bit quantization is supported')
    graph = tf.get_default_graph()

    # Insert custom quant ops.
    if quant_overrides_config is not None:
      # NOTE(review): the producer->consumer map is built once, before any
      # override is inserted, so two overrides on the same producer would
      # see stale consumers for the second one.
      input_to_ops_map = input_to_ops.InputToOps(graph)
      for q in quant_overrides_config.quant_configs:
        # Graph.get_operation_by_name raises KeyError for an unknown name
        # (it never returns None), so translate that into the ValueError this
        # function documents.
        try:
          producer = graph.get_operation_by_name(q.op_name)
        except KeyError:
          raise ValueError(
              'Op name does not exist in graph: %s' % q.op_name)
        context = _get_context_from_op(producer)
        consumers = input_to_ops_map.ConsumerOperations(producer)
        if q.fixed_range:
          _insert_fixed_quant_op(
              context,
              q.quant_op_name,
              producer,
              consumers,
              init_min=q.min,
              init_max=q.max,
              # Delays only make sense while training; eval/export quantize
              # immediately.
              quant_delay=q.delay if is_training else 0)
        else:
          raise ValueError('Learned ranges are not yet supported.')

    # Quantize the graph by inserting quantize ops for weights and activations.
    if is_training:
      # The quantization delay is reused as the batch-norm freeze delay.
      contrib_quantize.experimental_create_training_graph(
          input_graph=graph,
          quant_delay=quantization.delay,
          freeze_bn_delay=quantization.delay)
    else:
      # Exported graphs are quantized from step 0.
      contrib_quantize.experimental_create_eval_graph(
          input_graph=graph,
          quant_delay=0 if is_export else quantization.delay)
    contrib_layers.summarize_collection('quant_vars')

  return graph_rewrite_fn
def _get_context_from_op(op):
"""Gets the root context name from the op name."""
context_re = re.search(r'^(.*)/([^/]+)', op.name)
if context_re:
return context_re.group(1)
return ''
def _insert_fixed_quant_op(context,
                           name,
                           producer,
                           consumers,
                           init_min=-6.0,
                           init_max=6.0,
                           quant_delay=None):
  """Inserts a fixed-range fake quant op on the first output of `producer`.

  When `quant_delay` is a positive number, the quantized value is selected
  only once the quantization step reaches `quant_delay`. Before that, the
  raw tensor passes through unchanged. Every op in `consumers` is then
  rerouted to read the (possibly delayed) quantized tensor.

  Args:
    context: The parent scope of the op to be quantized.
    name: The name of the fake quant op.
    producer: The producer op to be quantized.
    consumers: The consumer ops to the producer op.
    init_min: The minimum range for the fake quant op.
    init_max: The maximum range for the fake quant op.
    quant_delay: Number of steps to wait before activating the fake quant op.

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  scope = context + '/' + name if context else name
  source = producer.outputs[0]
  fake_quant = quant_ops.FixedQuantize(
      source, init_min=init_min, init_max=init_max, scope=scope)

  rewired = fake_quant
  if quant_delay and quant_delay > 0:
    step_reached = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=scope + '/activate_quant')
    rewired = control_flow_ops.cond(
        step_reached,
        lambda: fake_quant,
        lambda: source,
        name=scope + '/delayed_quant')

  if not consumers:
    return

  rerouted_count = common.RerouteTensor(
      rewired, source, can_modify=consumers)
  # A consumer may read several outputs of the producer, and `consumers` is a
  # set, so the reroute count must be at least the number of distinct
  # consumers. Anything less means some consumer was not rewired.
  if rerouted_count < len(consumers):
    consumer_names = ', '.join([op.name for op in consumers])
    raise ValueError('No inputs quantized for ops: [%s]' % consumer_names)
|