Spaces:
Runtime error
Runtime error
Upload mi_gru_cell.py
Browse files- mi_gru_cell.py +64 -0
mi_gru_cell.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
class MiGRUCell(tf.nn.rnn_cell.RNNCell):
    """GRU-style recurrent cell whose gates mix input and state multiplicatively.

    Every gate pre-activation is built from an input projection ``Wx`` and a
    state projection ``Uh`` as a learned, element-wise scaled sum of ``Wx``,
    ``Uh`` and their product ``Wx * Uh``, plus an additive bias (see
    ``addBiases``).  NOTE(review): the class name suggests a
    "multiplicative integration" GRU -- confirm against the intended paper.

    Args:
        num_units: size of the hidden state; also the output size.
        input_size: accepted for signature compatibility; never read here.
        activation: callable applied to the candidate pre-activation.
        reuse: forwarded as ``reuse`` to the cell's variable scope.
    """

    def __init__(self, num_units, input_size=None, activation=tf.tanh, reuse=None):
        self.numUnits = num_units
        self.activation = activation
        self.reuse = reuse

    @property
    def state_size(self):
        """Width of the recurrent state (equal to ``num_units``)."""
        return self.numUnits

    @property
    def output_size(self):
        """Width of the emitted output (equal to ``num_units``)."""
        return self.numUnits

    def mulWeights(self, inp, inDim, outDim, name=""):
        """Right-multiply ``inp`` by a Xavier-initialized weight matrix.

        The matrix is the variable ``"weights"`` of shape ``(inDim, outDim)``
        living in the variable scope ``"weights" + name``.

        Returns:
            ``tf.matmul(inp, W)``.
        """
        with tf.variable_scope("weights" + name):
            kernel = tf.get_variable(
                "weights",
                shape=(inDim, outDim),
                initializer=tf.contrib.layers.xavier_initializer(),
            )

        return tf.matmul(inp, kernel)

    def addBiases(self, inp1, inp2, dim, bInitial=0, name=""):
        """Combine two projections with multiplicative and additive biases.

        Computes ``beta1 * inp1 + beta2 * inp2 + beta3 * (inp1 * inp2) + b``
        where ``beta`` is a ones-initialized vector of length ``3 * dim``
        (split into three equal parts along axis 1 after scaling) and ``b``
        is a zero-initialized vector of length ``dim`` shifted by the
        constant ``bInitial``.

        Args:
            inp1: first projection; concatenated along axis 1.
            inp2: second projection; concatenated along axis 1.
            dim: length of the additive bias (a third of ``beta``'s length).
            bInitial: constant added to the additive bias variable.
            name: suffix for both bias variable scopes.

        Returns:
            The combined pre-activation tensor.
        """
        with tf.variable_scope("additiveBiases" + name):
            # The constant is added to the variable rather than used as its
            # initializer, so the stored variable itself starts at zero.
            additive = tf.get_variable(
                "biases",
                shape=(dim,),
                initializer=tf.zeros_initializer(),
            ) + bInitial
        with tf.variable_scope("multiplicativeBias" + name):
            scale = tf.get_variable(
                "biases",
                shape=(3 * dim,),
                initializer=tf.ones_initializer(),
            )

        # Stack [inp1, inp2, inp1 * inp2] side by side, scale every column by
        # its own coefficient, then cut the result back into three parts.
        product = inp1 * inp2
        stacked = tf.concat([inp1, inp2, product], axis=1)
        scaled = scale * stacked
        first, second, cross = tf.split(scaled, num_or_size_splits=3, axis=1)

        return first + second + cross + additive

    def __call__(self, inputs, state, scope=None):
        """Run one recurrent step.

        Args:
            inputs: input tensor; its second dimension must be statically
                known because it is converted with ``int(...)``.
            state: previous hidden state, multiplied by
                ``(num_units, num_units)`` weight matrices.
            scope: variable scope name; defaults to the class name.

        Returns:
            ``(newH, newH)`` -- the new hidden state, used as both the
            output and the next state.
        """
        if not scope:
            scope = type(self).__name__

        units = self.numUnits
        with tf.variable_scope(scope, reuse=self.reuse):
            inputSize = int(inputs.shape[1])

            # Reset gate; its additive bias is shifted by +1.
            resetIn = self.mulWeights(inputs, inputSize, units, name="Wxr")
            resetHid = self.mulWeights(state, units, units, name="Uhr")
            resetPre = self.addBiases(resetIn, resetHid, units, bInitial=1, name="r")
            r = tf.nn.sigmoid(resetPre)

            # Update gate; its additive bias is shifted by +1.
            updateIn = self.mulWeights(inputs, inputSize, units, name="Wxu")
            updateHid = self.mulWeights(state, units, units, name="Uhu")
            updatePre = self.addBiases(updateIn, updateHid, units, bInitial=1, name="u")
            u = tf.nn.sigmoid(updatePre)

            # Candidate state, computed from the input and the reset-gated
            # previous state.
            candIn = self.mulWeights(inputs, inputSize, units, name="Wxl")
            candHid = self.mulWeights(r * state, units, units, name="Uhl")
            candPre = self.addBiases(candIn, candHid, units, name="2")
            c = self.activation(candPre)

            # u weights the old state and (1 - u) weights the candidate.
            # NOTE(review): original author asked "switch u and 1-u?".
            kept = u * state
            fresh = (1 - u) * c
            newH = kept + fresh

        return newH, newH

    def zero_state(self, batchSize, dtype=tf.float32):
        """Return an all-zeros state of shape ``(batchSize, num_units)``."""
        stateShape = (batchSize, self.numUnits)
        return tf.zeros(stateShape, dtype=dtype)
|
64 |
+
|