File size: 3,804 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for LSTM tensorflow blocks."""
from __future__ import division

import numpy as np
import tensorflow as tf

import block_base
import blocks_std
import blocks_lstm


class BlocksLSTMTest(tf.test.TestCase):
  """Tests the graph structure and bias initialization of the LSTM blocks."""

  def CheckUnary(self, y, op_type):
    """Asserts that `y` is produced by a one-input op of type `op_type`.

    Args:
      y: Tensor whose producing op is inspected.
      op_type: Expected `type` of the op that produced `y`.

    Returns:
      The single input of the op that produced `y`.
    """
    producer = y.op
    self.assertEqual(op_type, producer.type)
    operands = producer.inputs
    self.assertEqual(1, len(operands))
    return operands[0]

  def CheckBinary(self, y, op_type):
    """Asserts that `y` is produced by a two-input op of type `op_type`.

    Args:
      y: Tensor whose producing op is inspected.
      op_type: Expected `type` of the op that produced `y`.

    Returns:
      The `inputs` of the op that produced `y` (asserted to hold two entries),
      so callers can unpack it as a pair.
    """
    producer = y.op
    self.assertEqual(op_type, producer.type)
    operands = producer.inputs
    self.assertEqual(2, len(operands))
    return operands

  def testLSTM(self):
    """Verifies the op structure produced by calling an LSTM block.

    Asserted structure:
      y    = Sigmoid(LSTM/split:3) * Tanh(cell)
      cell = Sigmoid(LSTM/split:0) * <unchecked>
             + Sigmoid(LSTM/split:1) * Tanh(LSTM/split:2)
    """
    block = blocks_lstm.LSTM(10)
    block.hidden = tf.zeros(dtype=tf.float32, shape=[10, 10])
    block.cell = tf.zeros(dtype=tf.float32, shape=[10, 10])
    inputs = tf.placeholder(shape=[10, 11], dtype=tf.float32)
    output = block(inputs)

    # The output is a product of a sigmoid term and the tanh of the cell.
    sig_out, tanh_cell = self.CheckBinary(output, 'Mul')
    out_source = self.CheckUnary(sig_out, 'Sigmoid')
    self.assertEqual(out_source.name, 'LSTM/split:3')

    # The Tanh must read the very tensor the block now exposes as its cell.
    tanh_operand = self.CheckUnary(tanh_cell, 'Tanh')
    self.assertIs(block.cell, tanh_operand)
    carried_term, update_term = self.CheckBinary(block.cell, 'Add')

    # First addend: sigmoid term times an operand that is not inspected.
    sig_carry, _ = self.CheckBinary(carried_term, 'Mul')
    carry_source = self.CheckUnary(sig_carry, 'Sigmoid')
    self.assertEqual(carry_source.name, 'LSTM/split:0')

    # Second addend: sigmoid term times a tanh term.
    sig_update, tanh_update = self.CheckBinary(update_term, 'Mul')
    update_source = self.CheckUnary(sig_update, 'Sigmoid')
    self.assertEqual(update_source.name, 'LSTM/split:1')
    tanh_source = self.CheckUnary(tanh_update, 'Tanh')
    self.assertEqual(tanh_source.name, 'LSTM/split:2')

  def testLSTMBiasInit(self):
    """Verifies the LSTM bias starts as 9 ones followed by 27 zeros."""
    block = blocks_lstm.LSTM(9)
    inputs = tf.placeholder(shape=[15, 7], dtype=tf.float32)
    block(inputs)  # Called for its side effect; the result is not used.
    bias_block = block._nn._bias

    with self.test_session():
      tf.global_variables_initializer().run()
      actual = bias_block._bias.eval()

      # NOTE(review): the leading ones presumably correspond to the slice
      # feeding 'split:0' -- confirm against blocks_lstm.
      expected = [1.0] * 9 + [0.0] * 27
      self.assertAllEqual(actual, expected)

  def testConv2DLSTM(self):
    """Verifies the op structure produced by calling a Conv2DLSTM block.

    Asserted structure:
      y    = Sigmoid(Conv2DLSTM/split:3) * Tanh(cell)
      cell = Sigmoid(Conv2DLSTM/split:0) * <unchecked>
             + Sigmoid(Conv2DLSTM/split:1) * Tanh(Conv2DLSTM/split:2)
    """
    block = blocks_lstm.Conv2DLSTM(
        depth=10,
        filter_size=[1, 1],
        hidden_filter_size=[1, 1],
        strides=[1, 1],
        padding='SAME')
    block.hidden = tf.zeros(dtype=tf.float32, shape=[10, 11, 11, 10])
    block.cell = tf.zeros(dtype=tf.float32, shape=[10, 11, 11, 10])
    inputs = tf.placeholder(shape=[10, 11, 11, 1], dtype=tf.float32)
    output = block(inputs)

    # The output is a product of a sigmoid term and the tanh of the cell.
    sig_out, tanh_cell = self.CheckBinary(output, 'Mul')
    out_source = self.CheckUnary(sig_out, 'Sigmoid')
    self.assertEqual(out_source.name, 'Conv2DLSTM/split:3')

    # The Tanh must read the very tensor the block now exposes as its cell.
    tanh_operand = self.CheckUnary(tanh_cell, 'Tanh')
    self.assertIs(block.cell, tanh_operand)
    carried_term, update_term = self.CheckBinary(block.cell, 'Add')

    # First addend: sigmoid term times an operand that is not inspected.
    sig_carry, _ = self.CheckBinary(carried_term, 'Mul')
    carry_source = self.CheckUnary(sig_carry, 'Sigmoid')
    self.assertEqual(carry_source.name, 'Conv2DLSTM/split:0')

    # Second addend: sigmoid term times a tanh term.
    sig_update, tanh_update = self.CheckBinary(update_term, 'Mul')
    update_source = self.CheckUnary(sig_update, 'Sigmoid')
    self.assertEqual(update_source.name, 'Conv2DLSTM/split:1')
    tanh_source = self.CheckUnary(tanh_update, 'Tanh')
    self.assertEqual(tanh_source.name, 'Conv2DLSTM/split:2')

  def testConv2DLSTMBiasInit(self):
    """Verifies the Conv2DLSTM bias starts as 9 ones followed by 27 zeros."""
    block = blocks_lstm.Conv2DLSTM(9, 1, 1, [1, 1], 'SAME')
    inputs = tf.placeholder(shape=[1, 7, 7, 7], dtype=tf.float32)
    block(inputs)  # Called for its side effect; the result is not used.
    bias_block = block._bias

    with self.test_session():
      tf.global_variables_initializer().run()
      actual = bias_block._bias.eval()

      # NOTE(review): the leading ones presumably correspond to the slice
      # feeding 'split:0' -- confirm against blocks_lstm.
      expected = [1.0] * 9 + [0.0] * 27
      self.assertAllEqual(actual, expected)


# When executed as a script, hand control to TensorFlow's test runner.
if __name__ == '__main__':
  tf.test.main()