# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""BottleneckConvLSTMCell implementation."""

import functools

import tensorflow.compat.v1 as tf
import tf_slim as slim

from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.contrib.framework.python.ops import variables as contrib_variables

import lstm_object_detection.lstm.utils as lstm_utils

class BottleneckConvLSTMCell(contrib_rnn.RNNCell):
  """Basic LSTM recurrent network cell using separable convolutions.

  The implementation is based on:
  Mobile Video Object Detection with Temporally-Aware Feature Maps
  https://arxiv.org/abs/1711.06368.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting at the beginning of training.

  This LSTM first projects inputs to the size of the output before doing gate
  computations. This saves params unless the input is less than a third of the
  state size channel-wise.
  """

  def __init__(self,
               filter_size,
               output_size,
               num_units,
               forget_bias=1.0,
               activation=tf.tanh,
               flatten_state=False,
               clip_state=False,
               output_bottleneck=False,
               pre_bottleneck=False,
               visualize_gates=False):
    """Initializes the basic LSTM cell.

    Args:
      filter_size: collection, conv filter size.
      output_size: collection, the width/height dimensions of the cell/output.
      num_units: int, The number of channels in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      activation: Activation function of the inner states.
      flatten_state: if True, the state tensor will be flattened and stored as
        a 2-d tensor. Use for exporting the model to tfmini.
      clip_state: if True, clip state between [-6, 6].
      output_bottleneck: if True, the cell bottleneck will be concatenated to
        the cell output.
      pre_bottleneck: if True, the cell assumes that bottlenecking was
        performed before the function was called.
      visualize_gates: if True, add histogram summaries of all gates and
        outputs to tensorboard.
    """
    self._filter_size = list(filter_size)
    self._output_size = list(output_size)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._activation = activation
    self._viz_gates = visualize_gates
    self._flatten_state = flatten_state
    self._clip_state = clip_state
    self._output_bottleneck = output_bottleneck
    self._pre_bottleneck = pre_bottleneck
    self._param_count = self._num_units
    for dim in self._output_size:
      self._param_count *= dim

  # These must be properties: callers use them as attributes, e.g.
  # `[-1] + self.output_size` and `self.state_size_flat` below.
  @property
  def state_size(self):
    return contrib_rnn.LSTMStateTuple(self._output_size + [self._num_units],
                                      self._output_size + [self._num_units])

  @property
  def state_size_flat(self):
    return contrib_rnn.LSTMStateTuple([self._param_count], [self._param_count])

  @property
  def output_size(self):
    return self._output_size + [self._num_units]

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM) with bottlenecking.

    Args:
      inputs: Input tensor at the current timestep.
      state: Tuple of tensors, the state and output at the previous timestep.
      scope: Optional scope.

    Returns:
      A tuple where the first element is the LSTM output and the second is
      an LSTMStateTuple of the state at the current timestep.
    """
    scope = scope or 'conv_lstm_cell'
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
      c, h = state

      # Unflatten state if necessary.
      if self._flatten_state:
        c = tf.reshape(c, [-1] + self.output_size)
        h = tf.reshape(h, [-1] + self.output_size)

      # Summary of input passed into cell.
      if self._viz_gates:
        slim.summaries.add_histogram_summary(inputs, 'cell_input')

      if self._pre_bottleneck:
        bottleneck = inputs
      else:
        bottleneck = slim.separable_conv2d(
            tf.concat([inputs, h], 3),
            self._num_units,
            self._filter_size,
            depth_multiplier=1,
            activation_fn=self._activation,
            normalizer_fn=None,
            scope='bottleneck')

        if self._viz_gates:
          slim.summaries.add_histogram_summary(bottleneck, 'bottleneck')

      concat = slim.separable_conv2d(
          bottleneck,
          4 * self._num_units,
          self._filter_size,
          depth_multiplier=1,
          activation_fn=None,
          normalizer_fn=None,
          scope='gates')

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = tf.split(concat, 4, 3)

      new_c = (
          c * tf.sigmoid(f + self._forget_bias) +
          tf.sigmoid(i) * self._activation(j))
      if self._clip_state:
        new_c = tf.clip_by_value(new_c, -6, 6)
      new_h = self._activation(new_c) * tf.sigmoid(o)

      # Summary of cell output and new state.
      if self._viz_gates:
        slim.summaries.add_histogram_summary(new_h, 'cell_output')
        slim.summaries.add_histogram_summary(new_c, 'cell_state')

      output = new_h
      if self._output_bottleneck:
        output = tf.concat([new_h, bottleneck], axis=3)

      # Reflatten state to store it.
      if self._flatten_state:
        new_c = tf.reshape(new_c, [-1, self._param_count])
        new_h = tf.reshape(new_h, [-1, self._param_count])

      return output, contrib_rnn.LSTMStateTuple(new_c, new_h)

  def init_state(self, state_name, batch_size, dtype, learned_state=False):
    """Creates an initial state compatible with this cell.

    Args:
      state_name: name of the state tensor.
      batch_size: model batch size.
      dtype: dtype for the tensor values, e.g. tf.float32.
      learned_state: whether the initial state should be learnable. If false,
        the initial state is set to all 0's.

    Returns:
      The created initial state.
    """
    state_size = (
        self.state_size_flat if self._flatten_state else self.state_size)
    # A list of 2 zero tensors, or variable tensors if learned_state is true.
    # pylint: disable=g-long-ternary,g-complex-comprehension
    ret_flat = [(contrib_variables.model_variable(
        state_name + str(i),
        shape=s,
        dtype=dtype,
        initializer=tf.truncated_normal_initializer(stddev=0.03))
                 if learned_state else tf.zeros(
                     [batch_size] + s, dtype=dtype, name=state_name))
                for i, s in enumerate(state_size)]

    # Duplicates the initial state across the batch axis if it's learned.
    if learned_state:
      ret_flat = [
          tf.stack([tensor
                    for i in range(int(batch_size))])
          for tensor in ret_flat
      ]
    for s, r in zip(state_size, ret_flat):
      r.set_shape([None] + s)
    return tf.nest.pack_sequence_as(structure=[1, 1], flat_sequence=ret_flat)

  def pre_bottleneck(self, inputs, state, input_index):
    """Apply pre-bottleneck projection to inputs.

    Pre-bottleneck operation maps features of different channels into the same
    dimension. The purpose of this op is to share the features from both large
    and small models in the same LSTM cell.

    Args:
      inputs: 4D Tensor with shape [batch_size x width x height x input_size].
      state: 4D Tensor with shape [batch_size x width x height x state_size].
      input_index: integer index indicating which base features the inputs
        correspond to.

    Returns:
      inputs: pre-bottlenecked inputs.

    Raises:
      ValueError: If pre_bottleneck is not set or inputs is not rank 4.
    """
    # Sometimes state is a tuple, in which case it cannot be modified, e.g.
    # during training, tf.contrib.training.SequenceQueueingStateSaver
    # returns the state as a tuple. This should not be an issue since we
    # only need to modify state[1] during export, when state should be a
    # list.
    if len(inputs.shape) != 4:
      raise ValueError('Expect rank 4 feature tensor.')
    if not self._flatten_state and len(state.shape) != 4:
      raise ValueError('Expect rank 4 state tensor.')
    if self._flatten_state and len(state.shape) != 2:
      raise ValueError('Expect rank 2 state tensor when flatten_state is set.')

    with tf.name_scope(None):
      state = tf.identity(state, name='raw_inputs/init_lstm_h')
    if self._flatten_state:
      batch_size = inputs.shape[0]
      height = inputs.shape[1]
      width = inputs.shape[2]
      state = tf.reshape(state, [batch_size, height, width, -1])
    with tf.variable_scope('conv_lstm_cell', reuse=tf.AUTO_REUSE):
      scope_name = 'bottleneck_%d' % input_index
      inputs = slim.separable_conv2d(
          tf.concat([inputs, state], 3),
          self.output_size[-1],
          self._filter_size,
          depth_multiplier=1,
          activation_fn=tf.nn.relu6,
          normalizer_fn=None,
          scope=scope_name)
      # For exporting inference graph, we only mark the first timestep.
      with tf.name_scope(None):
        inputs = tf.identity(
            inputs, name='raw_outputs/base_endpoint_%d' % (input_index + 1))
    return inputs
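

# The sketch below is illustrative and not part of the original API: it shows
# one plausible way to drive BottleneckConvLSTMCell for a single timestep in
# graph mode. All shapes, unit counts, and the function name itself are
# placeholder assumptions, not values from the paper or the surrounding code.
def _example_bottleneck_cell_usage():
  """Runs one timestep of BottleneckConvLSTMCell on dummy features (sketch)."""
  cell = BottleneckConvLSTMCell(
      filter_size=(3, 3), output_size=(10, 10), num_units=64)
  # init_state returns [c, h]; with learned_state=False both are zero tensors
  # of shape [batch_size] + output_size + [num_units].
  state = cell.init_state('lstm_state', batch_size=4, dtype=tf.float32)
  # Input features must spatially match output_size (here 10x10); the channel
  # count is free because the bottleneck projects it down to num_units.
  features = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
  output, new_state = cell(features, state)
  return output, new_state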


class GroupedConvLSTMCell(contrib_rnn.RNNCell):
  """Basic LSTM recurrent network cell using separable convolutions.

  The implementation is based on: https://arxiv.org/abs/1903.10172.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting at the beginning of training.

  This LSTM first projects inputs to the size of the output before doing gate
  computations. This saves params unless the input is less than a third of the
  state size channel-wise. Computation of bottlenecks and gates is divided
  into independent groups for further savings.
  """

  def __init__(self,
               filter_size,
               output_size,
               num_units,
               is_training,
               forget_bias=1.0,
               activation=tf.tanh,
               use_batch_norm=False,
               flatten_state=False,
               groups=4,
               clip_state=False,
               scale_state=False,
               output_bottleneck=False,
               pre_bottleneck=False,
               is_quantized=False,
               visualize_gates=False,
               conv_op_overrides=None):
    """Initializes the basic LSTM cell.

    Args:
      filter_size: collection, conv filter size.
      output_size: collection, the width/height dimensions of the cell/output.
      num_units: int, The number of channels in the LSTM cell.
      is_training: Whether the LSTM is in training mode.
      forget_bias: float, The bias added to forget gates (see above).
      activation: Activation function of the inner states.
      use_batch_norm: if True, use batch norm after convolution.
      flatten_state: if True, the state tensor will be flattened and stored as
        a 2-d tensor. Use for exporting the model to tfmini.
      groups: Number of groups to split the state into. Must evenly divide
        num_units.
      clip_state: if True, clips state between [-6, 6].
      scale_state: if True, scales state so that all values are under 6 at all
        times.
      output_bottleneck: if True, the cell bottleneck will be concatenated to
        the cell output.
      pre_bottleneck: if True, the cell assumes that bottlenecking was
        performed before the function was called.
      is_quantized: if True, the model is in quantize mode, which requires
        quantization-friendly concat and separable_conv2d ops.
      visualize_gates: if True, add histogram summaries of all gates and
        outputs to tensorboard.
      conv_op_overrides: A list of convolutional operations that override the
        'bottleneck' and 'convolution' layers before the lstm gates. If None,
        the original implementation of separable_conv will be used. The length
        of the list should be two.

    Raises:
      ValueError: when both clip_state and scale_state are enabled.
    """
    if clip_state and scale_state:
      raise ValueError('clip_state and scale_state cannot both be enabled.')

    self._filter_size = list(filter_size)
    self._output_size = list(output_size)
    self._num_units = num_units
    self._is_training = is_training
    self._forget_bias = forget_bias
    self._activation = activation
    self._use_batch_norm = use_batch_norm
    self._viz_gates = visualize_gates
    self._flatten_state = flatten_state
    self._param_count = self._num_units
    self._groups = groups
    self._scale_state = scale_state
    self._clip_state = clip_state
    self._output_bottleneck = output_bottleneck
    self._pre_bottleneck = pre_bottleneck
    self._is_quantized = is_quantized
    for dim in self._output_size:
      self._param_count *= dim
    self._conv_op_overrides = conv_op_overrides
    if self._conv_op_overrides and len(self._conv_op_overrides) != 2:
      raise ValueError('Bottleneck and convolutional layers must be '
                       'overridden together.')

  # These must be properties: callers use them as attributes, e.g.
  # `[-1] + self.output_size` and `self.state_size_flat` below.
  @property
  def state_size(self):
    return contrib_rnn.LSTMStateTuple(self._output_size + [self._num_units],
                                      self._output_size + [self._num_units])

  @property
  def state_size_flat(self):
    return contrib_rnn.LSTMStateTuple([self._param_count], [self._param_count])

  @property
  def output_size(self):
    return self._output_size + [self._num_units]

  @property
  def filter_size(self):
    return self._filter_size

  @property
  def num_groups(self):
    return self._groups

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM) with bottlenecking.

    Includes logic for quantization-aware training. Note that all concats and
    activations use fixed ranges unless stated otherwise.

    Args:
      inputs: Input tensor at the current timestep.
      state: Tuple of tensors, the state at the previous timestep.
      scope: Optional scope.

    Returns:
      A tuple where the first element is the LSTM output and the second is
      an LSTMStateTuple of the state at the current timestep.
    """
    scope = scope or 'conv_lstm_cell'
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
      c, h = state

      # Set nodes to be under raw_inputs/ name scope for tfmini export.
      with tf.name_scope(None):
        c = tf.identity(c, name='raw_inputs/init_lstm_c')
        # When pre_bottleneck is enabled, the input h is handled in
        # rnn_decoder.py instead.
        if not self._pre_bottleneck:
          h = tf.identity(h, name='raw_inputs/init_lstm_h')

      # Unflatten state if necessary.
      if self._flatten_state:
        c = tf.reshape(c, [-1] + self.output_size)
        h = tf.reshape(h, [-1] + self.output_size)

      c_list = tf.split(c, self._groups, axis=3)
      if self._pre_bottleneck:
        inputs_list = tf.split(inputs, self._groups, axis=3)
      else:
        h_list = tf.split(h, self._groups, axis=3)
      out_bottleneck = []
      out_c = []
      out_h = []

      # Summary of input passed into cell.
      if self._viz_gates:
        slim.summaries.add_histogram_summary(inputs, 'cell_input')

      for k in range(self._groups):
        if self._pre_bottleneck:
          bottleneck = inputs_list[k]
        else:
          if self._conv_op_overrides:
            bottleneck_fn = self._conv_op_overrides[0]
          else:
            bottleneck_fn = functools.partial(
                lstm_utils.quantizable_separable_conv2d,
                kernel_size=self._filter_size,
                activation_fn=self._activation)
          if self._use_batch_norm:
            b_x = bottleneck_fn(
                inputs=inputs,
                num_outputs=self._num_units // self._groups,
                is_quantized=self._is_quantized,
                depth_multiplier=1,
                normalizer_fn=None,
                scope='bottleneck_%d_x' % k)
            b_h = bottleneck_fn(
                inputs=h_list[k],
                num_outputs=self._num_units // self._groups,
                is_quantized=self._is_quantized,
                depth_multiplier=1,
                normalizer_fn=None,
                scope='bottleneck_%d_h' % k)
            b_x = slim.batch_norm(
                b_x,
                scale=True,
                is_training=self._is_training,
                scope='BatchNorm_%d_X' % k)
            b_h = slim.batch_norm(
                b_h,
                scale=True,
                is_training=self._is_training,
                scope='BatchNorm_%d_H' % k)
            bottleneck = b_x + b_h
          else:
            # All concats use fixed quantization ranges to prevent rescaling
            # at inference. Both |inputs| and |h_list| are tensors resulting
            # from Relu6 operations so we fix the ranges to [0, 6].
            bottleneck_concat = lstm_utils.quantizable_concat(
                [inputs, h_list[k]],
                axis=3,
                is_training=False,
                is_quantized=self._is_quantized,
                scope='bottleneck_%d/quantized_concat' % k)
            bottleneck = bottleneck_fn(
                inputs=bottleneck_concat,
                num_outputs=self._num_units // self._groups,
                is_quantized=self._is_quantized,
                depth_multiplier=1,
                normalizer_fn=None,
                scope='bottleneck_%d' % k)

        if self._conv_op_overrides:
          conv_fn = self._conv_op_overrides[1]
        else:
          conv_fn = functools.partial(
              lstm_utils.quantizable_separable_conv2d,
              kernel_size=self._filter_size,
              activation_fn=None)
        concat = conv_fn(
            inputs=bottleneck,
            num_outputs=4 * self._num_units // self._groups,
            is_quantized=self._is_quantized,
            depth_multiplier=1,
            normalizer_fn=None,
            scope='concat_conv_%d' % k)

        # Since there is no activation in the previous separable conv, we
        # quantize here. A starting range of [-6, 6] is used because the
        # tensors are input to a Sigmoid function that saturates at these
        # ranges.
        concat = lstm_utils.quantize_op(
            concat,
            is_training=self._is_training,
            default_min=-6,
            default_max=6,
            is_quantized=self._is_quantized,
            scope='gates_%d/act_quant' % k)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(concat, 4, 3)

        f_add = f + self._forget_bias
        f_add = lstm_utils.quantize_op(
            f_add,
            is_training=self._is_training,
            default_min=-6,
            default_max=6,
            is_quantized=self._is_quantized,
            scope='forget_gate_%d/add_quant' % k)
        f_act = tf.sigmoid(f_add)

        a = c_list[k] * f_act
        a = lstm_utils.quantize_op(
            a,
            is_training=self._is_training,
            is_quantized=self._is_quantized,
            scope='forget_gate_%d/mul_quant' % k)

        i_act = tf.sigmoid(i)

        j_act = self._activation(j)
        # The quantization range is fixed for the relu6 to ensure that zero
        # is exactly representable.
        j_act = lstm_utils.fixed_quantize_op(
            j_act,
            fixed_min=0.0,
            fixed_max=6.0,
            is_quantized=self._is_quantized,
            scope='new_input_%d/act_quant' % k)

        b = i_act * j_act
        b = lstm_utils.quantize_op(
            b,
            is_training=self._is_training,
            is_quantized=self._is_quantized,
            scope='input_gate_%d/mul_quant' % k)

        new_c = a + b
        # The quantization range is fixed to [0, 6] due to an optimization in
        # TFLite. The order of operations is as follows:
        #   Add -> FakeQuant -> Relu6 -> FakeQuant -> Concat.
        # The fakequant ranges to the concat must be fixed to ensure all
        # inputs to the concat have the same range, removing the need for
        # rescaling. The quantization ranges input to the relu6 are propagated
        # to its output. Any mismatch between these two ranges will cause an
        # error.
        new_c = lstm_utils.fixed_quantize_op(
            new_c,
            fixed_min=0.0,
            fixed_max=6.0,
            is_quantized=self._is_quantized,
            scope='new_c_%d/add_quant' % k)

        if not self._is_quantized:
          if self._scale_state:
            normalizer = tf.maximum(1.0,
                                    tf.reduce_max(new_c, axis=(1, 2, 3)) / 6)
            new_c /= tf.reshape(normalizer, [tf.shape(new_c)[0], 1, 1, 1])
          elif self._clip_state:
            new_c = tf.clip_by_value(new_c, -6, 6)

        new_c_act = self._activation(new_c)
        # The quantization range is fixed for the relu6 to ensure that zero
        # is exactly representable.
        new_c_act = lstm_utils.fixed_quantize_op(
            new_c_act,
            fixed_min=0.0,
            fixed_max=6.0,
            is_quantized=self._is_quantized,
            scope='new_c_%d/act_quant' % k)

        o_act = tf.sigmoid(o)

        new_h = new_c_act * o_act
        # The quantization range is fixed since it is input to a concat.
        # A range of [0, 6] is used since |new_h| is a product of ranges
        # [0, 6] and [0, 1].
        new_h_act = lstm_utils.fixed_quantize_op(
            new_h,
            fixed_min=0.0,
            fixed_max=6.0,
            is_quantized=self._is_quantized,
            scope='new_h_%d/act_quant' % k)

        out_bottleneck.append(bottleneck)
        out_c.append(new_c_act)
        out_h.append(new_h_act)

      # Since all inputs to the below concats are already quantized, we can
      # use a regular concat operation.
      new_c = tf.concat(out_c, axis=3)
      new_h = tf.concat(out_h, axis=3)

      # |bottleneck| is input to a concat with |new_h|. We must use
      # quantizable_concat() with a fixed range that matches |new_h|.
      bottleneck = lstm_utils.quantizable_concat(
          out_bottleneck,
          axis=3,
          is_training=False,
          is_quantized=self._is_quantized,
          scope='out_bottleneck/quantized_concat')

      # Summary of cell output and new state.
      if self._viz_gates:
        slim.summaries.add_histogram_summary(new_h, 'cell_output')
        slim.summaries.add_histogram_summary(new_c, 'cell_state')

      output = new_h
      if self._output_bottleneck:
        output = lstm_utils.quantizable_concat(
            [new_h, bottleneck],
            axis=3,
            is_training=False,
            is_quantized=self._is_quantized,
            scope='new_output/quantized_concat')

      # Reflatten state to store it.
      if self._flatten_state:
        new_c = tf.reshape(new_c, [-1, self._param_count], name='lstm_c')
        new_h = tf.reshape(new_h, [-1, self._param_count], name='lstm_h')

      # Set nodes to be under raw_outputs/ name scope for tfmini export.
      with tf.name_scope(None):
        new_c = tf.identity(new_c, name='raw_outputs/lstm_c')
        new_h = tf.identity(new_h, name='raw_outputs/lstm_h')
      states_and_output = contrib_rnn.LSTMStateTuple(new_c, new_h)

      return output, states_and_output

  def init_state(self, state_name, batch_size, dtype, learned_state=False):
    """Creates an initial state compatible with this cell.

    Args:
      state_name: name of the state tensor.
      batch_size: model batch size.
      dtype: dtype for the tensor values, e.g. tf.float32.
      learned_state: whether the initial state should be learnable. If false,
        the initial state is set to all 0's.

    Returns:
      ret: the created initial state.
    """
    state_size = (
        self.state_size_flat if self._flatten_state else self.state_size)
    # A list of 2 zero tensors, or variable tensors if learned_state is true.
    # pylint: disable=g-long-ternary,g-complex-comprehension
    ret_flat = [(contrib_variables.model_variable(
        state_name + str(i),
        shape=s,
        dtype=dtype,
        initializer=tf.truncated_normal_initializer(stddev=0.03))
                 if learned_state else tf.zeros(
                     [batch_size] + s, dtype=dtype, name=state_name))
                for i, s in enumerate(state_size)]

    # Duplicates the initial state across the batch axis if it's learned.
    if learned_state:
      ret_flat = [tf.stack([tensor for i in range(int(batch_size))])
                  for tensor in ret_flat]
    # Rebuild the list so the reshape is actually applied (the original
    # reassigned the loop variable, discarding the reshaped tensor).
    ret_flat = [
        tf.reshape(r, [-1] + s) for s, r in zip(state_size, ret_flat)
    ]
    ret = tf.nest.pack_sequence_as(structure=[1, 1], flat_sequence=ret_flat)
    return ret

  def pre_bottleneck(self, inputs, state, input_index):
    """Apply pre-bottleneck projection to inputs.

    Pre-bottleneck operation maps features of different channels into the same
    dimension. The purpose of this op is to share the features from both large
    and small models in the same LSTM cell.

    Args:
      inputs: 4D Tensor with shape [batch_size x width x height x input_size].
      state: 4D Tensor with shape [batch_size x width x height x state_size].
      input_index: integer index indicating which base features the inputs
        correspond to.

    Returns:
      inputs: pre-bottlenecked inputs.

    Raises:
      ValueError: If pre_bottleneck is not set or inputs is not rank 4.
    """
    # Sometimes state is a tuple, in which case it cannot be modified, e.g.
    # during training, tf.contrib.training.SequenceQueueingStateSaver
    # returns the state as a tuple. This should not be an issue since we
    # only need to modify state[1] during export, when state should be a
    # list.
    if not self._pre_bottleneck:
      raise ValueError('Only applied when pre_bottleneck is set to true.')
    if len(inputs.shape) != 4:
      raise ValueError('Expect a rank 4 feature tensor.')
    if not self._flatten_state and len(state.shape) != 4:
      raise ValueError('Expect rank 4 state tensor.')
    if self._flatten_state and len(state.shape) != 2:
      raise ValueError('Expect rank 2 state tensor when flatten_state is set.')

    with tf.name_scope(None):
      state = tf.identity(
          state, name='raw_inputs/init_lstm_h_%d' % (input_index + 1))
    if self._flatten_state:
      batch_size = inputs.shape[0]
      height = inputs.shape[1]
      width = inputs.shape[2]
      state = tf.reshape(state, [batch_size, height, width, -1])
    with tf.variable_scope('conv_lstm_cell', reuse=tf.AUTO_REUSE):
      state_split = tf.split(state, self._groups, axis=3)
      with tf.variable_scope('bottleneck_%d' % input_index):
        bottleneck_out = []
        for k in range(self._groups):
          with tf.variable_scope('group_%d' % k):
            bottleneck_out.append(
                lstm_utils.quantizable_separable_conv2d(
                    lstm_utils.quantizable_concat(
                        [inputs, state_split[k]],
                        axis=3,
                        is_training=self._is_training,
                        is_quantized=self._is_quantized,
                        scope='quantized_concat'),
                    # Integer division so num_outputs stays an int in py3.
                    self.output_size[-1] // self._groups,
                    self._filter_size,
                    is_quantized=self._is_quantized,
                    depth_multiplier=1,
                    activation_fn=tf.nn.relu6,
                    normalizer_fn=None,
                    scope='project'))
        inputs = lstm_utils.quantizable_concat(
            bottleneck_out,
            axis=3,
            is_training=self._is_training,
            is_quantized=self._is_quantized,
            scope='bottleneck_out/quantized_concat')
      # For exporting inference graph, we only mark the first timestep.
      with tf.name_scope(None):
        inputs = tf.identity(
            inputs, name='raw_outputs/base_endpoint_%d' % (input_index + 1))
    return inputs
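

# Illustrative sketch, not part of the original API: one timestep through
# GroupedConvLSTMCell using the pre_bottleneck path. Shapes, the group count,
# unit counts, and the function name are arbitrary placeholder assumptions;
# num_units must be divisible by groups.
def _example_grouped_cell_usage():
  """Runs pre_bottleneck plus one GroupedConvLSTMCell step (sketch)."""
  cell = GroupedConvLSTMCell(
      filter_size=(3, 3),
      output_size=(10, 10),
      num_units=64,
      is_training=False,
      groups=4,
      pre_bottleneck=True)
  # init_state returns [c, h] zero tensors when learned_state=False.
  state = cell.init_state('lstm_state', batch_size=4, dtype=tf.float32)
  features = tf.zeros([4, 10, 10, 128], dtype=tf.float32)
  # Project raw features to the bottleneck dimension before the cell call;
  # input_index selects which base feature extractor the features came from,
  # and state[1] is the previous hidden state h.
  bottlenecked = cell.pre_bottleneck(features, state[1], input_index=0)
  output, new_state = cell(bottlenecked, state)
  return output, new_state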