import tensorflow as tf
import numpy as np

class MiLSTMCell(tf.nn.rnn_cell.RNNCell):
    """Multiplicative-integration LSTM cell (MI-LSTM; Wu et al., 2016).

    Each gate pre-activation combines the additive terms Wx and Uh with a
    multiplicative term Wx * Uh, each scaled by its own learned bias vector
    (see addBiases below).
    """

    def __init__(self, num_units, forget_bias = 1.0, input_size = None,
               state_is_tuple = True, activation = tf.tanh, reuse = None):
        # input_size and state_is_tuple are accepted only for signature
        # compatibility with tf.nn.rnn_cell.BasicLSTMCell; the state is always
        # an LSTMStateTuple, and the input size is inferred in __call__.
        self.numUnits = num_units
        self.forgetBias = forget_bias
        self.activation = activation
        self.reuse = reuse

    @property
    def state_size(self):
        return tf.nn.rnn_cell.LSTMStateTuple(self.numUnits, self.numUnits)          

    @property
    def output_size(self):
        return self.numUnits

    def mulWeights(self, inp, inDim, outDim, name = ""):
        # Linear projection: (batch, inDim) x (inDim, outDim) -> (batch, outDim),
        # with Xavier-initialized weights and no bias term.
        with tf.variable_scope("weights" + name):
            W = tf.get_variable("weights", shape = (inDim, outDim),
                initializer = tf.contrib.layers.xavier_initializer())
        output = tf.matmul(inp, W)
        return output

    def addBiases(self, inp1, inp2, dim, name = ""):
        # Multiplicative-integration combination of the two projections:
        #   output = beta1 * inp1 + beta2 * inp2 + beta3 * (inp1 * inp2) + b
        # where beta (shape 3 * dim) holds the per-term scaling factors and
        # b (shape dim) is the additive bias.
        with tf.variable_scope("additiveBiases" + name):
            b = tf.get_variable("biases", shape = (dim,),
                initializer = tf.zeros_initializer())
        with tf.variable_scope("multiplicativeBias" + name):
            beta = tf.get_variable("biases", shape = (3 * dim,),
                initializer = tf.ones_initializer())

        Wx, Uh, inter = tf.split(beta * tf.concat([inp1, inp2, inp1 * inp2], axis = 1),
            num_or_size_splits = 3, axis = 1)
        output = Wx + Uh + inter + b
        return output

    def __call__(self, inputs, state, scope = None):
        # One MI-LSTM step: maps (inputs, (c, h)) to (newH, (newC, newH)).
        scope = scope or type(self).__name__
        with tf.variable_scope(scope, reuse = self.reuse):
            c, h = state
            inputSize = int(inputs.shape[1])

            # Input gate i.
            Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxi")
            Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhi")

            i = self.addBiases(Wx, Uh, self.numUnits, name = "i")

            # Candidate (cell-input) gate j.
            Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxj")
            Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhj")
            
            j = self.addBiases(Wx, Uh, self.numUnits, name = "j")

            # Forget gate f.
            Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxf")
            Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhf")

            f = self.addBiases(Wx, Uh, self.numUnits, name = "f")

            # Output gate o.
            Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxo")
            Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uho")

            o = self.addBiases(Wx, Uh, self.numUnits, name = "o")

            # Standard LSTM state update on top of the MI gate pre-activations.
            newC = (c * tf.nn.sigmoid(f + self.forgetBias) + tf.nn.sigmoid(i) *
                    self.activation(j))
            newH = self.activation(newC) * tf.nn.sigmoid(o)

            newState = tf.nn.rnn_cell.LSTMStateTuple(newC, newH)
        return newH, newState

    def zero_state(self, batchSize, dtype = tf.float32):
        # All-zero initial (c, h) state for a batch of the given size.
        return tf.nn.rnn_cell.LSTMStateTuple(tf.zeros((batchSize, self.numUnits), dtype = dtype),
                                             tf.zeros((batchSize, self.numUnits), dtype = dtype))
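
# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original cell): wraps MiLSTMCell
# in tf.nn.dynamic_rnn and runs it once on random data. The shapes below
# (batchSize, seqLength, inputDim, numUnits) are assumptions chosen for the
# example only.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    batchSize, seqLength, inputDim, numUnits = 4, 10, 32, 64

    # (batch, time, features) input sequence placeholder.
    inputs = tf.placeholder(tf.float32, shape = (batchSize, seqLength, inputDim))

    cell = MiLSTMCell(numUnits)
    initialState = cell.zero_state(batchSize, dtype = tf.float32)

    # outputs: (batchSize, seqLength, numUnits); finalState: LSTMStateTuple(c, h).
    outputs, finalState = tf.nn.dynamic_rnn(cell, inputs,
        initial_state = initialState)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(outputs, feed_dict = {
            inputs: np.random.randn(batchSize, seqLength, inputDim).astype(np.float32)})
        print(out.shape)  # (4, 10, 64)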