|
import tensorflow as tf |
|
import tensorflow_probability as tfp |
|
from tensorflow import keras |
|
|
|
class TransformerEncoderLayer(keras.layers.Layer): |
|
def __init__(self, embed_dim, hidden_dim, num_heads, dropout_rate=0.1): |
|
super().__init__() |
|
|
|
self.attention = keras.layers.MultiHeadAttention( |
|
num_heads=num_heads, key_dim=embed_dim |
|
) |
|
self.feed_forward = keras.Sequential( |
|
[ |
|
keras.layers.Dense(hidden_dim, activation="relu"), |
|
keras.layers.Dense(embed_dim, activation=None) |
|
] |
|
) |
|
|
|
self.layernorm1 = keras.layers.LayerNormalization() |
|
self.layernorm2 = keras.layers.LayerNormalization() |
|
|
|
self.dropout1 = keras.layers.Dropout(dropout_rate) |
|
self.dropout2 = keras.layers.Dropout(dropout_rate) |
|
|
|
def call(self, inputs, padding_mask, training=False): |
|
attn_out = self.attention( |
|
query=inputs, |
|
value=inputs, |
|
key=inputs, |
|
attention_mask=padding_mask |
|
) |
|
attn_out = self.dropout1(attn_out, training=training) |
|
x = self.layernorm1(inputs + attn_out) |
|
|
|
ff_out = self.feed_forward(x) |
|
ff_out = self.dropout2(ff_out, training=training) |
|
return self.layernorm2(x + ff_out) |
|
|
|
|
|
class TransformerEncoder(keras.Model): |
|
def __init__(self, num_layers, seq_length, embed_dim, hidden_dim, num_heads, vocab_size, |
|
dropout_rate=0.1): |
|
super().__init__() |
|
|
|
self.embedding = PositionalEmbedding(seq_length, vocab_size, embed_dim) |
|
self.dropout = keras.layers.Dropout(dropout_rate) |
|
self.encoder_layers = [ |
|
TransformerEncoderLayer(embed_dim, hidden_dim, num_heads, dropout_rate=dropout_rate) |
|
for _ in range(num_layers) |
|
] |
|
|
|
def call(self, inputs, padding_mask, training=False): |
|
x = self.embedding(inputs) |
|
x = self.dropout(x, training=training) |
|
for i in range(len(self.encoder_layers)): |
|
x = self.encoder_layers[i](x, padding_mask, training=training) |
|
return x |
|
|
|
|
|
class TransformerDecoderLayer(keras.layers.Layer): |
|
def __init__(self, embed_dim, hidden_dim, num_heads, dropout_rate=0.1): |
|
super().__init__() |
|
|
|
self.self_attention = keras.layers.MultiHeadAttention( |
|
num_heads=num_heads, key_dim=embed_dim |
|
) |
|
self.attention = keras.layers.MultiHeadAttention( |
|
num_heads=num_heads, key_dim=embed_dim |
|
) |
|
self.feed_fordward = keras.Sequential( |
|
[ |
|
keras.layers.Dense(hidden_dim, activation="relu"), |
|
keras.layers.Dense(embed_dim, activation=None) |
|
] |
|
) |
|
|
|
self.layernorm1 = keras.layers.LayerNormalization() |
|
self.layernorm2 = keras.layers.LayerNormalization() |
|
self.layernorm3 = keras.layers.LayerNormalization() |
|
|
|
self.dropout1 = keras.layers.Dropout(dropout_rate) |
|
self.dropout2 = keras.layers.Dropout(dropout_rate) |
|
self.dropout3 = keras.layers.Dropout(dropout_rate) |
|
|
|
def call(self, inputs, encoder_outputs, look_ahead_mask, training=False, padding_mask=None): |
|
self_attn_out = self.self_attention( |
|
query=inputs, |
|
value=inputs, |
|
key=inputs, |
|
attention_mask=look_ahead_mask |
|
) |
|
self_attn_out = self.dropout1(self_attn_out, training=training) |
|
x = self.layernorm1(inputs + self_attn_out) |
|
|
|
attn_out = self.attention( |
|
query=x, |
|
value=encoder_outputs, |
|
key=encoder_outputs, |
|
attention_mask=padding_mask |
|
) |
|
attn_out = self.dropout2(attn_out, training=training) |
|
x = self.layernorm2(x + attn_out) |
|
|
|
ff_out = self.feed_fordward(x) |
|
ff_out = self.dropout3(ff_out, training=training) |
|
return self.layernorm3(x + ff_out) |
|
|
|
|
|
class TransformerDecoder(keras.Model): |
|
def __init__(self, num_layers, seq_length, embed_dim, hidden_dim, num_heads, vocab_size, |
|
dropout_rate=0.1): |
|
super().__init__() |
|
|
|
self.embedding = PositionalEmbedding(seq_length, vocab_size, embed_dim) |
|
self.dropout = keras.layers.Dropout(dropout_rate) |
|
self.decoder_layers = [ |
|
TransformerDecoderLayer(embed_dim, hidden_dim, num_heads, dropout_rate=dropout_rate) |
|
for _ in range(num_layers) |
|
] |
|
|
|
def call(self, inputs, encoder_outputs, training=False, padding_mask=None): |
|
look_ahead_mask = get_look_ahead_mask(inputs) |
|
x = self.embedding(inputs) |
|
x = self.dropout(x, training=training) |
|
for i in range(len(self.decoder_layers)): |
|
x = self.decoder_layers[i](x, encoder_outputs, look_ahead_mask, |
|
training=training, padding_mask=padding_mask) |
|
return x |
|
|
|
|
|
def get_padding_mask(inputs): |
|
mask = tf.cast(tf.math.not_equal(inputs, 0), tf.int32) |
|
return mask[:, tf.newaxis, :] |
|
|
|
def get_look_ahead_mask(inputs): |
|
input_shape = tf.shape(inputs) |
|
batch_size, seq_length = input_shape[0], input_shape[1] |
|
n = int(seq_length * (seq_length + 1) / 2) |
|
mask = tfp.math.fill_triangular(tf.ones((n,), dtype=tf.int32)) |
|
mask = tf.repeat(mask[tf.newaxis, :], batch_size, axis=0) |
|
return tf.minimum(mask, get_padding_mask(inputs)) |
|
|
|
|
|
class PositionalEmbedding(keras.layers.Layer): |
|
def __init__(self, seq_length, vocab_size, embed_dim): |
|
super().__init__() |
|
|
|
self.token_embeddings = keras.layers.Embedding( |
|
input_dim=vocab_size, output_dim=embed_dim |
|
) |
|
self.position_embeddings = keras.layers.Embedding( |
|
input_dim=seq_length, output_dim=embed_dim |
|
) |
|
|
|
def call(self, inputs): |
|
positions = tf.range(start=0, limit=tf.shape(inputs)[-1], delta=1) |
|
embedded_tokens = self.token_embeddings(inputs) |
|
embedded_positions = self.position_embeddings(positions) |
|
return embedded_tokens + embedded_positions |
|
|
|
|
|
class Transformer(keras.Model): |
|
def __init__(self, encoder_layers, decoder_layers, input_seq_length, target_seq_length, embed_dim, |
|
hidden_dim, num_heads, input_vocab_size, target_vocab_size, dropout_rate=0.1): |
|
super().__init__() |
|
|
|
self.encoder = TransformerEncoder(encoder_layers, input_seq_length, embed_dim, hidden_dim, |
|
num_heads, input_vocab_size, dropout_rate=dropout_rate) |
|
self.decoder = TransformerDecoder(decoder_layers, target_seq_length, embed_dim, hidden_dim, |
|
num_heads, target_vocab_size, dropout_rate=dropout_rate) |
|
self.linear = keras.layers.Dense(target_vocab_size, activation=None) |
|
|
|
def call(self, inputs, training=False): |
|
encoder_inputs, targets = inputs |
|
padding_mask = get_padding_mask(encoder_inputs) |
|
encoder_outputs = self.encoder(encoder_inputs, padding_mask, training=training) |
|
decoder_outputs = self.decoder(targets, encoder_outputs, training=training, |
|
padding_mask=padding_mask) |
|
return self.linear(decoder_outputs) |