# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Custom RNN decoder."""
import tensorflow.compat.v1 as tf
import lstm_object_detection.lstm.utils as lstm_utils


class _NoVariableScope(object):
  """No-op context manager, used when no variable scope is supplied."""

  def __enter__(self):
    return

  def __exit__(self, exc_type, exc_value, traceback):
    return False


def rnn_decoder(decoder_inputs,
                initial_state,
                cell,
                loop_function=None,
                scope=None):
  """RNN decoder for the LSTM-SSD model.

  This decoder returns a list of all states, rather than only the final state.

  Args:
    decoder_inputs: A list of 4D Tensors with shape
      [batch_size, height, width, depth].
    initial_state: 2D Tensor with shape [batch_size x cell.state_size].
    cell: rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to the i-th
      output in order to generate the (i+1)-st input, and decoder_inputs will
      be ignored, except for the first element ("GO" symbol). This can be used
      for decoding, but also for training to emulate
      http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is the output Tensor of step i,
        * i is an integer, the step number (when advanced control is needed),
        * next is the input Tensor for step i + 1.
    scope: optional VariableScope for the created subgraph.

  Returns:
    A tuple of the form (outputs, states), where:
      outputs: A list of the same length as decoder_inputs of Tensors with
        the same shape as the cell output, containing the generated outputs.
      states: A list of the same length as decoder_inputs containing the
        state of the cell at each time-step.
  """
  with tf.variable_scope(scope) if scope else _NoVariableScope():
    state_tuple = initial_state
    outputs = []
    states = []
    prev = None
    for local_step, decoder_input in enumerate(decoder_inputs):
      if loop_function is not None and prev is not None:
        with tf.variable_scope('loop_function', reuse=True):
          decoder_input = loop_function(prev, local_step)
      output, state_tuple = cell(decoder_input, state_tuple)
      outputs.append(output)
      states.append(state_tuple)
      if loop_function is not None:
        prev = output
  return outputs, states
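

# A minimal usage sketch (illustrative, not part of the original module): any
# RNNCell with the call signature cell(inputs, state) works here. LSTM-SSD
# itself uses a convolutional LSTM cell on 4D feature maps; a plain LSTMCell
# on 2D inputs is shown only to keep the example self-contained.
def _example_rnn_decoder():
  batch_size, input_size, unroll_length = 4, 8, 3
  cell = tf.nn.rnn_cell.LSTMCell(num_units=16)
  inputs = [tf.zeros([batch_size, input_size]) for _ in range(unroll_length)]
  initial_state = cell.zero_state(batch_size, tf.float32)
  # Returns one output and one state tuple per timestep.
  return rnn_decoder(inputs, initial_state, cell, scope='example_decoder')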


def multi_input_rnn_decoder(decoder_inputs,
                            initial_state,
                            cell,
                            sequence_step,
                            selection_strategy='RANDOM',
                            is_training=None,
                            is_quantized=False,
                            preprocess_fn_list=None,
                            pre_bottleneck=False,
                            flatten_state=False,
                            scope=None):
  """RNN decoder for the Interleaved LSTM-SSD model.

  This decoder takes multiple sequences of inputs and selects the input to
  feed to the rnn at each timestep using its selection_strategy, which can be
  random, learned, or deterministic.
  This decoder returns a list of all states, rather than only the final state.

  Args:
    decoder_inputs: A list of lists of 4D Tensors with shape
      [batch_size, height, width, depth], one list per input sequence.
    initial_state: 2D Tensor with shape [batch_size x cell.state_size].
    cell: rnn_cell.RNNCell defining the cell function and size.
    sequence_step: Tensor [batch_size] of the step number of the first
      elements in the sequence.
    selection_strategy: Method for picking the decoder_input to use at each
      timestep. Must be 'RANDOM', or 'SKIPX' for an integer X, where X is the
      number of times to use the second input before using the first.
    is_training: boolean, whether the network is training. When using learned
      selection, attempts exploration if training.
    is_quantized: flag to enable/disable quantization mode.
    preprocess_fn_list: List of functions accepting two tensor arguments: one
      timestep of decoder_inputs and the lstm state. If not None,
      decoder_inputs[i] will be updated with preprocess_fn_list[i] at the
      start of each timestep.
    pre_bottleneck: if True, use separate bottleneck weights for each
      sequence. Useful when input sequences have differing numbers of
      channels. Final bottlenecks will have the same dimension.
    flatten_state: Whether the LSTM state is flattened.
    scope: optional VariableScope for the created subgraph.

  Returns:
    A tuple of the form (outputs, states), where:
      outputs: A list of the same length as each element of decoder_inputs,
        containing the generated outputs.
      states: A list of the same length as each element of decoder_inputs,
        containing the state of the cell at each time-step.

  Raises:
    ValueError: If selection_strategy is not recognized or unexpected unroll
      length.
  """
  if flatten_state and len(decoder_inputs[0]) > 1:
    raise ValueError('In export mode, unroll length should not be more than 1')
  with tf.variable_scope(scope) if scope else _NoVariableScope():
    state_tuple = initial_state
    outputs = []
    states = []
    batch_size = decoder_inputs[0][0].shape[0].value
    num_sequences = len(decoder_inputs)
    sequence_length = len(decoder_inputs[0])
    for local_step in range(sequence_length):
      for sequence_index in range(num_sequences):
        if preprocess_fn_list is not None:
          decoder_inputs[sequence_index][local_step] = (
              preprocess_fn_list[sequence_index](
                  decoder_inputs[sequence_index][local_step], state_tuple[0]))
        if pre_bottleneck:
          decoder_inputs[sequence_index][local_step] = cell.pre_bottleneck(
              inputs=decoder_inputs[sequence_index][local_step],
              state=state_tuple[1],
              input_index=sequence_index)

      action = generate_action(selection_strategy, local_step, sequence_step,
                               [batch_size, 1, 1, 1])
      inputs, _ = (
          select_inputs(decoder_inputs, action, local_step, is_training,
                        is_quantized))
      # Mark base network endpoints under raw_inputs/.
      with tf.name_scope(None):
        inputs = tf.identity(inputs, 'raw_inputs/base_endpoint')
      output, state_tuple_out = cell(inputs, state_tuple)
      state_tuple = select_state(state_tuple, state_tuple_out, action)
      outputs.append(output)
      states.append(state_tuple)
  return outputs, states
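

# Hedged usage sketch (illustrative only): drives the interleaved decoder with
# two parallel feature streams. `conv_lstm_cell` is assumed to be a
# convolutional LSTM cell from this package (e.g. a BottleneckConvLSTMCell)
# whose state is a (c, h) tuple of 4D tensors; its construction is omitted
# here to avoid guessing at its exact signature.
def _example_multi_input_rnn_decoder(conv_lstm_cell, initial_state):
  batch_size, height, width, depth, unroll_length = 2, 4, 4, 8, 3
  large_stream = [tf.zeros([batch_size, height, width, depth])
                  for _ in range(unroll_length)]
  small_stream = [tf.zeros([batch_size, height, width, depth])
                  for _ in range(unroll_length)]
  # Step number of the first frame of each sequence in the batch.
  sequence_step = tf.zeros([batch_size], dtype=tf.int32)
  # 'SKIP1' alternates: large model on even steps, small model on odd steps.
  return multi_input_rnn_decoder(
      [large_stream, small_stream],
      initial_state,
      conv_lstm_cell,
      sequence_step,
      selection_strategy='SKIP1',
      is_training=False)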


def generate_action(selection_strategy, local_step, sequence_step,
                    action_shape):
  """Generates the current (binary) action based on the selection strategy.

  Args:
    selection_strategy: Method for picking the decoder_input to use at each
      timestep. Must be 'RANDOM', or 'SKIPX' for an integer X, where X is the
      number of times to use the second input before using the first.
    local_step: integer step number within the current unrolled batch.
    sequence_step: Tensor [batch_size] of the step number of the first
      elements in the sequence.
    action_shape: The shape of the action tensor to be generated.

  Returns:
    A tensor of shape action_shape, where each element is an individual
    action.

  Raises:
    ValueError: if selection_strategy is not supported or if 'SKIP' is not
      followed by an integer.
  """
  if selection_strategy.startswith('RANDOM'):
    action = tf.random.uniform(action_shape, maxval=2, dtype=tf.int32)
    action = tf.minimum(action, 1)

    # The first step always runs the large network.
    if local_step == 0 and sequence_step is not None:
      action *= tf.minimum(
          tf.reshape(tf.cast(sequence_step, tf.int32), action_shape), 1)
  elif selection_strategy.startswith('SKIP'):
    inter_count = int(selection_strategy[4:])
    if local_step % (inter_count + 1) == 0:
      action = tf.zeros(action_shape)
    else:
      action = tf.ones(action_shape)
  else:
    raise ValueError('Selection strategy %s not recognized' %
                     selection_strategy)
  return tf.cast(action, tf.int32)
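

# Illustrative sketch of the 'SKIPX' pattern (not part of the original
# module): with 'SKIP2', every third step (local_step % 3 == 0) emits action 0
# (run the large model) and the two steps in between emit action 1 (run the
# small model).
def _example_generate_action():
  sequence_step = tf.ones([2], dtype=tf.int32)
  return [
      generate_action('SKIP2', step, sequence_step, [2, 1, 1, 1])
      for step in range(6)
  ]  # Per-step actions: 0, 1, 1, 0, 1, 1.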


def select_inputs(decoder_inputs, action, local_step, is_training,
                  is_quantized, get_alt_inputs=False):
  """Selects the timestep input from decoder_inputs based on binary actions.

  Given multiple input batches, creates a single output batch by selecting,
  for the i-th batch element, the input from the action[i]-th sequence.

  Args:
    decoder_inputs: A 2-D list of tensor inputs.
    action: A tensor whose shape broadcasts against the stacked inputs, e.g.
      [batch_size, 1, 1, 1] for 4D inputs. Each element is an index into
      decoder_inputs.
    local_step: The current timestep.
    is_training: boolean, whether the network is training. When using learned
      selection, attempts exploration if training.
    is_quantized: flag to enable/disable quantization mode.
    get_alt_inputs: Whether the non-chosen inputs should also be returned.

  Returns:
    The constructed output. Also outputs the elements that were not chosen
    if get_alt_inputs is True, otherwise None.

  Raises:
    ValueError: if decoder_inputs contains other than two sequences.
  """
  num_seqs = len(decoder_inputs)
  if num_seqs != 2:
    raise ValueError('Currently only supports two sets of inputs.')
  stacked_inputs = tf.stack(
      [decoder_inputs[seq_index][local_step] for seq_index in range(num_seqs)],
      axis=-1)
  action_index = tf.one_hot(action, num_seqs)
  selected_inputs = (
      lstm_utils.quantize_op(stacked_inputs * action_index, is_training,
                             is_quantized, scope='quant_selected_inputs'))
  inputs = tf.reduce_sum(selected_inputs, axis=-1)
  inputs_alt = None
  # Only works for two models.
  if get_alt_inputs:
    # Complement of action_index.
    action_index_alt = tf.one_hot(action, num_seqs, on_value=0.0,
                                  off_value=1.0)
    selected_inputs = (
        lstm_utils.quantize_op(stacked_inputs * action_index_alt, is_training,
                               is_quantized,
                               scope='quant_selected_inputs_alt'))
    inputs_alt = tf.reduce_sum(selected_inputs, axis=-1)
  return inputs, inputs_alt
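

# Illustrative sketch (not part of the original module): a batch element with
# action 0 takes its input from the first sequence, and one with action 1
# from the second. Inputs are 4D so the [batch_size, 1, 1, 1]-shaped action
# broadcasts over height, width, and depth.
def _example_select_inputs():
  first = [tf.zeros([2, 4, 4, 8])]
  second = [tf.ones([2, 4, 4, 8])]
  action = tf.reshape(tf.constant([0, 1], dtype=tf.int32), [2, 1, 1, 1])
  inputs, _ = select_inputs([first, second], action, local_step=0,
                            is_training=False, is_quantized=False)
  return inputs  # Batch element 0 is all zeros, element 1 is all ones.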


def select_state(previous_state, new_state, action):
  """Selects the state given the action.

  Currently only supports a binary action. If the action is 0, the state was
  generated from the large model, so we update the state; if the action is 1,
  the state was generated from the small model, and in the interleaved model
  we skip this state update.

  Args:
    previous_state: A state tuple representing the state from the previous
      step.
    new_state: A state tuple representing the newly computed state.
    action: A tensor that broadcasts against the state tensors, e.g.
      [batch_size, 1, 1, 1].

  Returns:
    A state tuple selected based on the given action.
  """
  action = tf.cast(action, tf.float32)
  state_c = previous_state[0] * action + new_state[0] * (1 - action)
  state_h = previous_state[1] * action + new_state[1] * (1 - action)
  return (state_c, state_h)
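

# Illustrative sketch (not part of the original module): with a binary action
# the elementwise blend above acts as a hard per-batch-element switch. Batch
# element 0 (action 0) takes the newly computed state; batch element 1
# (action 1) keeps the previous state.
def _example_select_state():
  previous_state = (tf.zeros([2, 4, 4, 8]), tf.zeros([2, 4, 4, 8]))
  new_state = (tf.ones([2, 4, 4, 8]), tf.ones([2, 4, 4, 8]))
  action = tf.reshape(tf.constant([0, 1], dtype=tf.int32), [2, 1, 1, 1])
  return select_state(previous_state, new_state, action)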