API-Journaling/utils/attention_layer.py
nabilaalt's picture
initial
780a315
raw
history blame contribute delete
866 Bytes
import tensorflow as tf
class AttentionLayer(tf.keras.layers.Layer):
    """Additive (Bahdanau-style) attention pooling over the time axis.

    Expects inputs of shape (batch, time, features) — TODO confirm with the
    caller — and returns a context vector of shape (batch, features): the
    softmax-score-weighted sum of the timesteps.
    """

    def __init__(self, **kwargs):
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # All weights are sized on the last (feature) dimension of the input.
        dim = input_shape[-1]
        # Named weights make checkpoints and debugging output readable.
        self.W = self.add_weight(name="attention_W", shape=(dim, dim),
                                 initializer="glorot_uniform", trainable=True)
        self.b = self.add_weight(name="attention_b", shape=(dim,),
                                 initializer="zeros", trainable=True)
        self.u = self.add_weight(name="attention_u", shape=(dim, 1),
                                 initializer="glorot_uniform", trainable=True)
        # Mark the layer as built (sets self.built = True) — the original
        # omitted this recommended call.
        super(AttentionLayer, self).build(input_shape)

    def call(self, inputs):
        # Hidden representation: tanh(inputs @ W + b), shape (batch, time, features).
        u_t = tf.tanh(tf.tensordot(inputs, self.W, axes=1) + self.b)
        # Per-timestep scores, softmax-normalized over the time axis (axis=1);
        # shape (batch, time, 1), broadcasts against inputs below.
        score = tf.nn.softmax(tf.tensordot(u_t, self.u, axes=1), axis=1)
        # Weighted sum over time -> (batch, features).
        context = tf.reduce_sum(inputs * score, axis=1)
        return context

    def get_config(self):
        # Layer has no extra constructor arguments; exposing the base config
        # is enough to let Keras save/load models containing this layer.
        return super(AttentionLayer, self).get_config()