import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras


class TransformerEncoderLayer(keras.layers.Layer):
    """One post-norm Transformer encoder block.

    Multi-head self-attention followed by a two-layer feed-forward network.
    Each sub-layer output goes through dropout, is added back to its input
    (residual connection) and is then layer-normalized:
    ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_dim: Model dimension. The feed-forward network projects back to
            this size so the residual sum is well-formed.
        hidden_dim: Width of the feed-forward network's hidden ReLU layer.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate applied to each sub-layer's output.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # NOTE: Keras' ``key_dim`` is the size of *each* head, so every head
        # here is embed_dim wide (not embed_dim // num_heads). Without an
        # explicit ``output_shape`` the layer projects back to the query's
        # feature dimension.
        self.attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.feed_forward = keras.Sequential(
            [
                keras.layers.Dense(hidden_dim, activation="relu"),
                keras.layers.Dense(embed_dim, activation=None),
            ]
        )
        self.layernorm1 = keras.layers.LayerNormalization()
        self.layernorm2 = keras.layers.LayerNormalization()
        self.dropout1 = keras.layers.Dropout(dropout_rate)
        self.dropout2 = keras.layers.Dropout(dropout_rate)

    def call(self, inputs, padding_mask, training=False):
        """Run self-attention and feed-forward sub-layers.

        Args:
            inputs: Embedded sequence; used as query, key and value.
            padding_mask: Attention mask passed to ``MultiHeadAttention``
                (as built by ``get_padding_mask``: 1 = attend, 0 = masked).
            training: Enables dropout when True.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        attn_out = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        attn_out = self.dropout1(attn_out, training=training)
        x = self.layernorm1(inputs + attn_out)
        ff_out = self.feed_forward(x)
        ff_out = self.dropout2(ff_out, training=training)
        return self.layernorm2(x + ff_out)


class TransformerEncoder(keras.Model):
    """Token + position embedding followed by a stack of encoder layers.

    Args:
        num_layers: Number of ``TransformerEncoderLayer`` blocks.
        seq_length: Maximum input length (size of the position-embedding table).
        embed_dim: Model dimension.
        hidden_dim: Feed-forward hidden width in each layer.
        num_heads: Attention heads per layer.
        vocab_size: Size of the input token vocabulary.
        dropout_rate: Dropout rate for the embeddings and every layer.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_layers, seq_length, embed_dim, hidden_dim, num_heads,
                 vocab_size, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embedding = PositionalEmbedding(seq_length, vocab_size, embed_dim)
        self.dropout = keras.layers.Dropout(dropout_rate)
        self.encoder_layers = [
            TransformerEncoderLayer(embed_dim, hidden_dim, num_heads,
                                    dropout_rate=dropout_rate)
            for _ in range(num_layers)
        ]

    def call(self, inputs, padding_mask, training=False):
        """Encode token ids.

        Args:
            inputs: Token-id tensor fed to ``PositionalEmbedding``.
            padding_mask: Self-attention mask shared by all encoder layers.
            training: Enables dropout when True.

        Returns:
            Encoded sequence of shape ``inputs.shape + (embed_dim,)``.
        """
        x = self.embedding(inputs)
        x = self.dropout(x, training=training)
        for layer in self.encoder_layers:
            x = layer(x, padding_mask, training=training)
        return x


class TransformerDecoderLayer(keras.layers.Layer):
    """One post-norm Transformer decoder block.

    Masked self-attention, then cross-attention over the encoder outputs,
    then a feed-forward network. Each sub-layer uses the
    ``LayerNorm(x + Dropout(sublayer(x)))`` residual pattern.

    Args:
        embed_dim: Model dimension.
        hidden_dim: Width of the feed-forward network's hidden ReLU layer.
        num_heads: Number of attention heads (for both attention sub-layers).
        dropout_rate: Dropout rate applied to each sub-layer's output.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # key_dim is per-head in Keras (see TransformerEncoderLayer).
        self.self_attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        # NOTE: the attribute name is misspelled ("fordward"), but it is kept
        # on purpose: Keras/TF checkpoints key weights by attribute name, so
        # renaming it would stop existing checkpoints from restoring.
        self.feed_fordward = keras.Sequential(
            [
                keras.layers.Dense(hidden_dim, activation="relu"),
                keras.layers.Dense(embed_dim, activation=None),
            ]
        )
        self.layernorm1 = keras.layers.LayerNormalization()
        self.layernorm2 = keras.layers.LayerNormalization()
        self.layernorm3 = keras.layers.LayerNormalization()
        self.dropout1 = keras.layers.Dropout(dropout_rate)
        self.dropout2 = keras.layers.Dropout(dropout_rate)
        self.dropout3 = keras.layers.Dropout(dropout_rate)

    def call(self, inputs, encoder_outputs, look_ahead_mask, training=False,
             padding_mask=None):
        """Run masked self-attention, cross-attention and feed-forward.

        Args:
            inputs: Embedded target sequence.
            encoder_outputs: Encoder output; used as key and value of the
                cross-attention.
            look_ahead_mask: Self-attention mask (as built by
                ``get_look_ahead_mask``: causal and target padding).
            training: Enables dropout when True.
            padding_mask: Optional cross-attention mask over the encoder
                positions; None attends to every encoder position.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        self_attn_out = self.self_attention(
            query=inputs, value=inputs, key=inputs, attention_mask=look_ahead_mask
        )
        self_attn_out = self.dropout1(self_attn_out, training=training)
        x = self.layernorm1(inputs + self_attn_out)

        attn_out = self.attention(
            query=x, value=encoder_outputs, key=encoder_outputs,
            attention_mask=padding_mask,
        )
        attn_out = self.dropout2(attn_out, training=training)
        x = self.layernorm2(x + attn_out)

        ff_out = self.feed_fordward(x)
        ff_out = self.dropout3(ff_out, training=training)
        return self.layernorm3(x + ff_out)


class TransformerDecoder(keras.Model):
    """Token + position embedding followed by a stack of decoder layers.

    Args:
        num_layers: Number of ``TransformerDecoderLayer`` blocks.
        seq_length: Maximum target length (size of the position-embedding table).
        embed_dim: Model dimension.
        hidden_dim: Feed-forward hidden width in each layer.
        num_heads: Attention heads per layer.
        vocab_size: Size of the target token vocabulary.
        dropout_rate: Dropout rate for the embeddings and every layer.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_layers, seq_length, embed_dim, hidden_dim, num_heads,
                 vocab_size, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embedding = PositionalEmbedding(seq_length, vocab_size, embed_dim)
        self.dropout = keras.layers.Dropout(dropout_rate)
        self.decoder_layers = [
            TransformerDecoderLayer(embed_dim, hidden_dim, num_heads,
                                    dropout_rate=dropout_rate)
            for _ in range(num_layers)
        ]

    def call(self, inputs, encoder_outputs, training=False, padding_mask=None):
        """Decode target token ids against the encoder outputs.

        The causal + padding self-attention mask is derived from ``inputs``.

        Args:
            inputs: Target token-id tensor.
            encoder_outputs: Output of ``TransformerEncoder``.
            training: Enables dropout when True.
            padding_mask: Optional cross-attention mask over encoder positions.

        Returns:
            Decoded sequence of shape ``inputs.shape + (embed_dim,)``.
        """
        look_ahead_mask = get_look_ahead_mask(inputs)
        x = self.embedding(inputs)
        x = self.dropout(x, training=training)
        for layer in self.decoder_layers:
            x = layer(x, encoder_outputs, look_ahead_mask, training=training,
                      padding_mask=padding_mask)
        return x


def get_padding_mask(inputs):
    """Return an attention mask that hides padding tokens (token id 0).

    Args:
        inputs: Integer token-id tensor; presumably ``(batch, seq)``. TODO:
            confirm for callers passing other ranks.

    Returns:
        int32 tensor equal to ``inputs != 0`` with a new axis inserted at
        position 1, i.e. ``(batch, 1, seq)`` for rank-2 inputs. The new axis
        broadcasts over every query position.
    """
    mask = tf.cast(tf.math.not_equal(inputs, 0), tf.int32)
    return mask[:, tf.newaxis, :]


def get_look_ahead_mask(inputs):
    """Build the decoder self-attention mask (causal AND non-padding).

    Query position ``i`` may attend to key position ``j`` only when
    ``j <= i`` and token ``j`` is not padding (id 0).

    Args:
        inputs: Integer token-id tensor. Dim 0 is the batch and dim 1 the
            sequence length.

    Returns:
        int32 tensor of shape ``(batch, seq_length, seq_length)``: 1 where
        attention is allowed, 0 where it is masked.
    """
    seq_length = tf.shape(inputs)[1]
    # Lower-triangular ones (diagonal and below). This is built from the
    # dynamic shape tensor, so it also works inside tf.function / model.fit,
    # where a Python int() of a symbolic shape is not possible.
    causal = tf.linalg.band_part(
        tf.ones((seq_length, seq_length), dtype=tf.int32), -1, 0
    )
    # (seq, seq) broadcast against the (batch, 1, seq) padding mask
    # -> (batch, seq, seq).
    return tf.minimum(causal[tf.newaxis, :, :], get_padding_mask(inputs))


class PositionalEmbedding(keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        seq_length: Number of position embeddings. Inputs longer than this
            cannot be embedded.
        vocab_size: Number of token embeddings.
        embed_dim: Embedding size.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, seq_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = keras.layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = keras.layers.Embedding(
            input_dim=seq_length, output_dim=embed_dim
        )

    def call(self, inputs):
        """Embed token ids and add the embedding of each position 0..len-1.

        The position index runs along the last axis of ``inputs``. The
        ``(len, embed_dim)`` position embeddings broadcast over the leading
        axes.
        """
        positions = tf.range(start=0, limit=tf.shape(inputs)[-1], delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions


class Transformer(keras.Model):
    """Encoder-decoder Transformer producing per-position target-vocab logits.

    Args:
        encoder_layers: Number of encoder layers.
        decoder_layers: Number of decoder layers.
        input_seq_length: Maximum source length.
        target_seq_length: Maximum target length.
        embed_dim: Model dimension.
        hidden_dim: Feed-forward hidden width.
        num_heads: Attention heads per attention layer.
        input_vocab_size: Source vocabulary size.
        target_vocab_size: Target vocabulary size (also the logits width).
        dropout_rate: Dropout rate throughout the model.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, encoder_layers, decoder_layers, input_seq_length,
                 target_seq_length, embed_dim, hidden_dim, num_heads,
                 input_vocab_size, target_vocab_size, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.encoder = TransformerEncoder(
            encoder_layers, input_seq_length, embed_dim, hidden_dim, num_heads,
            input_vocab_size, dropout_rate=dropout_rate,
        )
        self.decoder = TransformerDecoder(
            decoder_layers, target_seq_length, embed_dim, hidden_dim, num_heads,
            target_vocab_size, dropout_rate=dropout_rate,
        )
        # Linear output layer (no activation): returns logits, not probabilities.
        self.linear = keras.layers.Dense(target_vocab_size, activation=None)

    def call(self, inputs, training=False):
        """Run the full model.

        Args:
            inputs: Pair ``(encoder_inputs, targets)`` of token-id tensors.
            training: Enables dropout when True.

        Returns:
            Logits over the target vocabulary for every target position.
        """
        encoder_inputs, targets = inputs
        # The source padding mask serves both the encoder self-attention and
        # the decoder cross-attention, since both attend over source positions.
        padding_mask = get_padding_mask(encoder_inputs)
        encoder_outputs = self.encoder(encoder_inputs, padding_mask, training=training)
        decoder_outputs = self.decoder(
            targets, encoder_outputs, training=training, padding_mask=padding_mask
        )
        return self.linear(decoder_outputs)