import tensorflow as tf

from library.self_attention import scaled_dot_product_attention


class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # d_model must split evenly across the attention heads.
        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        # Linear projections for the queries, keys, and values.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        # Final linear projection applied to the concatenated heads.
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Transpose the result such that the shape is
        (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]

        # Project the inputs to d_model dimensions.
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        # Split the projected tensors into (num_heads, depth) per position.
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention: (batch_size, num_heads, seq_len_q, depth)
        # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # Restore the (batch_size, seq_len_q, num_heads, depth) ordering,
        # then merge the heads back into a single d_model dimension.
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
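

# A minimal usage sketch (not part of the library; the hyperparameters and
# tensor shapes below are illustrative): self-attention over a dummy batch,
# relying on scaled_dot_product_attention returning (output, attention_weights)
# as it is used in call() above.
def _example_multi_head_attention():
    mha = MultiHeadAttention(d_model=512, num_heads=8)
    x = tf.random.uniform((1, 60, 512))  # (batch_size, seq_len, d_model)
    out, attn = mha(x, k=x, q=x, mask=None)
    # out: (1, 60, 512), attn: (1, 8, 60, 60)
    return out, attn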


def point_wise_feed_forward_network(d_model, dff):
    # Two dense layers applied to each position independently: expand to dff
    # units with a ReLU, then project back down to d_model.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),
        tf.keras.layers.Dense(d_model)
    ])
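

# A minimal usage sketch (not part of the library; the d_model/dff values are
# illustrative): the feed-forward block maps (batch_size, seq_len, d_model)
# back to the same shape, expanding to dff units in between.
if __name__ == "__main__":
    sample_ffn = point_wise_feed_forward_network(d_model=512, dff=2048)
    x = tf.random.uniform((64, 50, 512))  # (batch_size, seq_len, d_model)
    print(sample_ffn(x).shape)  # (64, 50, 512)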