# Copyright 2017 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for LSTM tensorflow blocks.""" from __future__ import division import numpy as np import tensorflow as tf import block_base import blocks_std import blocks_lstm class BlocksLSTMTest(tf.test.TestCase): def CheckUnary(self, y, op_type): self.assertEqual(op_type, y.op.type) self.assertEqual(1, len(y.op.inputs)) return y.op.inputs[0] def CheckBinary(self, y, op_type): self.assertEqual(op_type, y.op.type) self.assertEqual(2, len(y.op.inputs)) return y.op.inputs def testLSTM(self): lstm = blocks_lstm.LSTM(10) lstm.hidden = tf.zeros(shape=[10, 10], dtype=tf.float32) lstm.cell = tf.zeros(shape=[10, 10], dtype=tf.float32) x = tf.placeholder(dtype=tf.float32, shape=[10, 11]) y = lstm(x) o, tanhc = self.CheckBinary(y, 'Mul') self.assertEqual(self.CheckUnary(o, 'Sigmoid').name, 'LSTM/split:3') self.assertIs(lstm.cell, self.CheckUnary(tanhc, 'Tanh')) fc, ij = self.CheckBinary(lstm.cell, 'Add') f, _ = self.CheckBinary(fc, 'Mul') self.assertEqual(self.CheckUnary(f, 'Sigmoid').name, 'LSTM/split:0') i, j = self.CheckBinary(ij, 'Mul') self.assertEqual(self.CheckUnary(i, 'Sigmoid').name, 'LSTM/split:1') j = self.CheckUnary(j, 'Tanh') self.assertEqual(j.name, 'LSTM/split:2') def testLSTMBiasInit(self): lstm = blocks_lstm.LSTM(9) x = tf.placeholder(dtype=tf.float32, shape=[15, 7]) lstm(x) b = lstm._nn._bias with self.test_session(): tf.global_variables_initializer().run() bias_var = b._bias.eval() comp = ([1.0] * 9) + ([0.0] * 27) self.assertAllEqual(bias_var, comp) def testConv2DLSTM(self): lstm = blocks_lstm.Conv2DLSTM(depth=10, filter_size=[1, 1], hidden_filter_size=[1, 1], strides=[1, 1], padding='SAME') lstm.hidden = tf.zeros(shape=[10, 11, 11, 10], dtype=tf.float32) lstm.cell = tf.zeros(shape=[10, 11, 11, 10], dtype=tf.float32) x = tf.placeholder(dtype=tf.float32, shape=[10, 11, 11, 1]) y = lstm(x) o, tanhc = self.CheckBinary(y, 'Mul') self.assertEqual(self.CheckUnary(o, 'Sigmoid').name, 'Conv2DLSTM/split:3') self.assertIs(lstm.cell, self.CheckUnary(tanhc, 'Tanh')) fc, ij = self.CheckBinary(lstm.cell, 'Add') f, _ = self.CheckBinary(fc, 'Mul') self.assertEqual(self.CheckUnary(f, 'Sigmoid').name, 'Conv2DLSTM/split:0') i, j = self.CheckBinary(ij, 'Mul') self.assertEqual(self.CheckUnary(i, 'Sigmoid').name, 'Conv2DLSTM/split:1') j = self.CheckUnary(j, 'Tanh') self.assertEqual(j.name, 'Conv2DLSTM/split:2') def testConv2DLSTMBiasInit(self): lstm = blocks_lstm.Conv2DLSTM(9, 1, 1, [1, 1], 'SAME') x = tf.placeholder(dtype=tf.float32, shape=[1, 7, 7, 7]) lstm(x) b = lstm._bias with self.test_session(): tf.global_variables_initializer().run() bias_var = b._bias.eval() comp = ([1.0] * 9) + ([0.0] * 27) self.assertAllEqual(bias_var, comp) if __name__ == '__main__': tf.test.main()