# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Custom version for quantized training and evaluation functions.

The main difference between this and the third_party graph_rewriter_builder.py
is that this version uses experimental_create_training_graph, which allows
customization of freeze_bn_delay.
"""

import re
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


def build(graph_rewriter_config,
          quant_overrides_config=None,
          is_training=True,
          is_export=False):
  """Returns a function that modifies default graph based on options.

  Args:
    graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
    quant_overrides_config: quant_overrides_pb2.QuantOverrides proto.
    is_training: whether in training or eval mode.
    is_export: whether exporting the graph.
  """
  def graph_rewrite_fn():
    """Function to quantize weights and activation of the default graph."""
    if (graph_rewriter_config.quantization.weight_bits != 8 or
        graph_rewriter_config.quantization.activation_bits != 8):
      raise ValueError('Only 8-bit quantization is supported')

    graph = tf.get_default_graph()

    # Insert custom quant ops.
    if quant_overrides_config is not None:
      input_to_ops_map = input_to_ops.InputToOps(graph)
      for q in quant_overrides_config.quant_configs:
        producer = graph.get_operation_by_name(q.op_name)
        if producer is None:
          raise ValueError('Op name does not exist in graph.')
        context = _get_context_from_op(producer)
        consumers = input_to_ops_map.ConsumerOperations(producer)
        if q.fixed_range:
          _insert_fixed_quant_op(
              context,
              q.quant_op_name,
              producer,
              consumers,
              init_min=q.min,
              init_max=q.max,
              quant_delay=q.delay if is_training else 0)
        else:
          raise ValueError('Learned ranges are not yet supported.')

    # Quantize the graph by inserting quantize ops for weights and activations.
    if is_training:
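      # Setting freeze_bn_delay explicitly (here tied to the quantization
      # delay) is the customization this module exists for; the third_party
      # graph_rewriter_builder.py does not expose that argument.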
      contrib_quantize.experimental_create_training_graph(
          input_graph=graph,
          quant_delay=graph_rewriter_config.quantization.delay,
          freeze_bn_delay=graph_rewriter_config.quantization.delay)
    else:
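      # When exporting, the delay is forced to 0 so the eval graph contains
      # immediately active fake quant ops rather than delayed ones.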
      contrib_quantize.experimental_create_eval_graph(
          input_graph=graph,
          quant_delay=graph_rewriter_config.quantization.delay
          if not is_export else 0)

    contrib_layers.summarize_collection('quant_vars')

  return graph_rewrite_fn
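
# Example usage (a minimal sketch; `pipeline_config` is assumed to be a parsed
# TrainEvalPipelineConfig from the Object Detection API and is not defined in
# this module):
#
#   graph_rewriter_fn = build(pipeline_config.graph_rewriter, is_training=True)
#   with tf.Graph().as_default():
#     # ... build the detection model into the default graph ...
#     graph_rewriter_fn()  # Rewrites the default graph in place.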


def _get_context_from_op(op):
  """Gets the root context name from the op name."""
  context_re = re.search(r'^(.*)/([^/]+)', op.name)
  if context_re:
    return context_re.group(1)
  return ''


def _insert_fixed_quant_op(context,
                           name,
                           producer,
                           consumers,
                           init_min=-6.0,
                           init_max=6.0,
                           quant_delay=None):
  """Adds a fake quant op with fixed ranges.

  Args:
    context: The parent scope of the op to be quantized.
    name: The name of the fake quant op.
    producer: The producer op to be quantized.
    consumers: The consumer ops of the producer op.
    init_min: The minimum range for the fake quant op.
    init_max: The maximum range for the fake quant op.
    quant_delay: Number of steps to wait before activating the fake quant op.

  Raises:
    ValueError: If the producer operation is not directly connected to a
      consumer operation.
  """
  name_prefix = name if not context else context + '/' + name
  inputs = producer.outputs[0]
  quant = quant_ops.FixedQuantize(
      inputs, init_min=init_min, init_max=init_max, scope=name_prefix)

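  # Optionally gate the fake quant op on the global quantization step: until
  # `quant_delay` steps have elapsed, the original (unquantized) tensor is
  # passed through instead.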
  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = common.RerouteTensor(
        quant, inputs, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer, so `tensors_modified_count` may exceed `len(consumers)`; if it
    # is smaller, at least one consumer's input was not rerouted through the
    # fake quant op.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
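

# A `quant_overrides_config` that exercises the fixed-range path above might
# look like the following text proto. The field names mirror the attributes
# read in build() (quant_configs, op_name, quant_op_name, fixed_range, min,
# max, delay); the authoritative schema is the QuantOverrides proto
# (quant_overrides_pb2), so treat this as an illustrative sketch only:
#
#   quant_configs {
#     op_name: "FeatureExtractor/some_scope/some_op"
#     quant_op_name: "custom_fixed_quant"
#     fixed_range: true
#     min: -6.0
#     max: 6.0
#     delay: 500000
#   }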