ydin0771 committed on
Commit
5115fbb
•
1 Parent(s): 86e0937

Upload mi_gru_cell.py

Browse files
Files changed (1) hide show
  1. mi_gru_cell.py +64 -0
mi_gru_cell.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+
4
class MiGRUCell(tf.nn.rnn_cell.RNNCell):
    """Gated recurrent cell whose pre-activations include an interaction term.

    Every pre-activation is assembled by ``addBiases`` from an input
    projection, a state projection and their elementwise product, each scaled
    by its own learned vector, plus an additive bias.

    NOTE(review): the "Mi" prefix presumably stands for "multiplicative
    integration"; that reading comes from the class name and the formula
    only -- confirm against the project's documentation.

    Args:
        num_units: Size reported by both ``state_size`` and ``output_size``.
        input_size: Accepted but never read by this class.
        activation: Callable applied to the candidate pre-activation.
        reuse: Passed as ``reuse`` to the variable scope opened in
            ``__call__``.
    """

    def __init__(self, num_units, input_size=None, activation=tf.tanh, reuse=None):
        self.numUnits = num_units
        self.activation = activation
        self.reuse = reuse

    @property
    def state_size(self):
        """Size of the recurrent state (``num_units``)."""
        return self.numUnits

    @property
    def output_size(self):
        """Size of the cell output (``num_units``)."""
        return self.numUnits

    def mulWeights(self, inp, inDim, outDim, name=""):
        """Return ``inp`` matrix-multiplied by an ``(inDim, outDim)`` variable.

        The variable is called ``"weights"`` and is fetched with
        ``tf.get_variable`` inside the variable scope ``"weights" + name``,
        using a Xavier initializer.
        """
        scopeName = "weights" + name
        with tf.variable_scope(scopeName):
            kernel = tf.get_variable(
                "weights",
                shape=(inDim, outDim),
                initializer=tf.contrib.layers.xavier_initializer(),
            )

        return tf.matmul(inp, kernel)

    def addBiases(self, inp1, inp2, dim, bInitial=0, name=""):
        """Combine two projections, their product and an additive bias.

        Computes ``beta1 * inp1 + beta2 * inp2 + beta3 * (inp1 * inp2) + b``
        where ``beta1..beta3`` are the three equal slices (along axis 1) of a
        single ``(3 * dim,)`` variable initialised to ones, and ``b`` is a
        ``(dim,)`` variable initialised to zeros with ``bInitial`` added to it
        in the graph.

        Args:
            inp1: First term; concatenated with ``inp2`` along axis 1.
            inp2: Second term; multiplied elementwise with ``inp1``.
            dim: Length of the additive bias; the scaling vector has
                ``3 * dim`` entries.
            bInitial: Constant added to the additive-bias variable.
            name: Suffix appended to both variable-scope names.
        """
        with tf.variable_scope("additiveBiases" + name):
            # The offset is applied as a graph op; the stored variable itself
            # starts at zero.
            additive = tf.get_variable(
                "biases",
                shape=(dim,),
                initializer=tf.zeros_initializer(),
            ) + bInitial
        with tf.variable_scope("multiplicativeBias" + name):
            scale = tf.get_variable(
                "biases",
                shape=(3 * dim,),
                initializer=tf.ones_initializer(),
            )

        # Stack [inp1, inp2, inp1 * inp2] side by side so that one broadcast
        # multiply applies all three scaling vectors, then cut it back apart.
        interaction = inp1 * inp2
        stacked = tf.concat([inp1, inp2, interaction], axis=1)
        scaled = scale * stacked
        first, second, cross = tf.split(scaled, num_or_size_splits=3, axis=1)

        return first + second + cross + additive

    def __call__(self, inputs, state, scope=None):
        """Run one step of the cell.

        Args:
            inputs: Input tensor; ``inputs.shape[1]`` is converted with
                ``int`` and used as the input feature size.
            state: Previous state tensor.
            scope: Variable scope to build under; when falsy, the class name
                is used.

        Returns:
            ``(newH, newH)`` -- the new hidden state, returned as both the
            output and the next state.
        """
        if not scope:
            scope = type(self).__name__
        units = self.numUnits

        with tf.variable_scope(scope, reuse=self.reuse):
            inputDim = int(inputs.shape[1])

            # Gate r: scales the previous state before the candidate's state
            # projection. Its additive bias is offset by 1.
            rInput = self.mulWeights(inputs, inputDim, units, name="Wxr")
            rState = self.mulWeights(state, units, units, name="Uhr")
            rPre = self.addBiases(rInput, rState, units, bInitial=1, name="r")
            r = tf.nn.sigmoid(rPre)

            # Gate u: interpolates between the previous state and the
            # candidate. Its additive bias is offset by 1.
            uInput = self.mulWeights(inputs, inputDim, units, name="Wxu")
            uState = self.mulWeights(state, units, units, name="Uhu")
            uPre = self.addBiases(uInput, uState, units, bInitial=1, name="u")
            u = tf.nn.sigmoid(uPre)

            # Candidate: built from the input and the r-gated state.
            cInput = self.mulWeights(inputs, inputDim, units, name="Wxl")
            gatedState = r * state
            cState = self.mulWeights(gatedState, units, units, name="Uhl")
            cPre = self.addBiases(cInput, cState, units, name="2")
            c = self.activation(cPre)

            # Open question carried over from the original author:
            # switch u and 1-u?
            kept = u * state
            fresh = (1 - u) * c
            newH = kept + fresh

        return newH, newH

    def zero_state(self, batchSize, dtype=tf.float32):
        """Return an all-zero state of shape ``(batchSize, numUnits)``."""
        shape = (batchSize, self.numUnits)
        return tf.zeros(shape, dtype=dtype)
64
+