ydin0771 commited on
Commit
86e0937
β€’
1 Parent(s): f07fdc1

Upload mi_lstm_cell.py

Browse files
Files changed (1) hide show
  1. mi_lstm_cell.py +77 -0
mi_lstm_cell.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+
4
+ class MiLSTMCell(tf.nn.rnn_cell.RNNCell):
5
+ def __init__(self, num_units, forget_bias = 1.0, input_size = None,
6
+ state_is_tuple = True, activation = tf.tanh, reuse = None):
7
+ self.numUnits = num_units
8
+ self.forgetBias = forget_bias
9
+ self.activation = activation
10
+ self.reuse = reuse
11
+
12
+ @property
13
+ def state_size(self):
14
+ return tf.nn.rnn_cell.LSTMStateTuple(self.numUnits, self.numUnits)
15
+
16
+ @property
17
+ def output_size(self):
18
+ return self.numUnits
19
+
20
+ def mulWeights(self, inp, inDim, outDim, name = ""):
21
+ with tf.variable_scope("weights" + name):
22
+ W = tf.get_variable("weights", shape = (inDim, outDim),
23
+ initializer = tf.contrib.layers.xavier_initializer())
24
+ output = tf.matmul(inp, W)
25
+ return output
26
+
27
+ def addBiases(self, inp1, inp2, dim, name = ""):
28
+ with tf.variable_scope("additiveBiases" + name):
29
+ b = tf.get_variable("biases", shape = (dim,),
30
+ initializer = tf.zeros_initializer())
31
+ with tf.variable_scope("multiplicativeBias" + name):
32
+ beta = tf.get_variable("biases", shape = (3 * dim,),
33
+ initializer = tf.ones_initializer())
34
+
35
+ Wx, Uh, inter = tf.split(beta * tf.concat([inp1, inp2, inp1 * inp2], axis = 1),
36
+ num_or_size_splits = 3, axis = 1)
37
+ output = Wx + Uh + inter + b
38
+ return output
39
+
40
+ def __call__(self, inputs, state, scope = None):
41
+ scope = scope or type(self).__name__
42
+ with tf.variable_scope(scope, reuse = self.reuse):
43
+ c, h = state
44
+ inputSize = int(inputs.shape[1])
45
+
46
+ Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxi")
47
+ Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhi")
48
+
49
+ i = self.addBiases(Wx, Uh, self.numUnits, name = "i")
50
+
51
+ Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxj")
52
+ Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhj")
53
+
54
+ j = self.addBiases(Wx, Uh, self.numUnits, name = "l")
55
+
56
+ Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxf")
57
+ Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhf")
58
+
59
+ f = self.addBiases(Wx, Uh, self.numUnits, name = "f")
60
+
61
+ Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxo")
62
+ Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uho")
63
+
64
+ o = self.addBiases(Wx, Uh, self.numUnits, name = "o")
65
+ # i, j, f, o = tf.split(value = concat, num_or_size_splits = 4, axis = 1)
66
+
67
+ newC = (c * tf.nn.sigmoid(f + self.forgetBias) + tf.nn.sigmoid(i) *
68
+ self.activation(j))
69
+ newH = self.activation(newC) * tf.nn.sigmoid(o)
70
+
71
+ newState = tf.nn.rnn_cell.LSTMStateTuple(newC, newH)
72
+ return newH, newState
73
+
74
+ def zero_state(self, batchSize, dtype = tf.float32):
75
+ return tf.nn.rnn_cell.LSTMStateTuple(tf.zeros((batchSize, self.numUnits), dtype = dtype),
76
+ tf.zeros((batchSize, self.numUnits), dtype = dtype))
77
+