import tensorflow as tf import numpy as np from transformers import BertTokenizer tokenizer_en = BertTokenizer.from_pretrained("bert-base-cased") tokenizer_cn = BertTokenizer.from_pretrained("bert-base-chinese") MAX_TOKENIZE_LENGTH = 128 EMBEDDING_DEPTH = 256 def positional_encoding(length, depth): depth = depth/2 positions = np.arange(length)[:, np.newaxis] # (seq, 1) depths = np.arange(depth)[np.newaxis, :]/depth # (1, depth) angle_rates = 1 / (10000**depths) # (1, depth) angle_rads = positions * angle_rates # (pos, depth) pos_encoding = np.concatenate( [np.sin(angle_rads), np.cos(angle_rads)], axis=-1) return tf.cast(pos_encoding, dtype=tf.float32) class PositionalEmbedding(tf.keras.layers.Layer): def __init__(self, vocab_size, d_model): super().__init__() self.d_model = d_model self.embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=d_model, mask_zero=True) self.pos_encoding = positional_encoding(length=MAX_TOKENIZE_LENGTH, depth=d_model) def compute_mask(self, *args, **kwargs): return self.embedding.compute_mask(*args, **kwargs) def call(self, x): length = tf.shape(x)[1] x = self.embedding(x) # This factor sets the relative scale of the embedding and positonal_encoding. x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) x = x + self.pos_encoding[tf.newaxis, :length, :] return x class BaseAttention(tf.keras.layers.Layer): def __init__(self, **kwargs): super().__init__() self.mha = tf.keras.layers.MultiHeadAttention(**kwargs) self.layernorm = tf.keras.layers.LayerNormalization() self.add = tf.keras.layers.Add() class CrossAttention(BaseAttention): def call(self, x, context): #x = query, content = key,value pairs attn_output, attn_scores = self.mha( query=x, key=context, value=context, return_attention_scores=True) # Cache the attention scores for plotting later. self.last_attn_scores = attn_scores x = self.add([x, attn_output]) x = self.layernorm(x) return x class GlobalSelfAttention(BaseAttention): def call(self, x): attn_output = self.mha( query=x, value=x, key=x) x = self.add([x, attn_output]) x = self.layernorm(x) return x class CausalSelfAttention(BaseAttention): def call(self, x): attn_output = self.mha( query=x, value=x, key=x, use_causal_mask = True) x = self.add([x, attn_output]) x = self.layernorm(x) return x class FeedForward(tf.keras.layers.Layer): def __init__(self, d_model, dff, dropout_rate=0.1): super().__init__() self.seq = tf.keras.Sequential([ tf.keras.layers.Dense(dff, activation='relu'), tf.keras.layers.Dense(d_model), tf.keras.layers.Dropout(dropout_rate) ]) self.add = tf.keras.layers.Add() self.layer_norm = tf.keras.layers.LayerNormalization() def call(self, x): x = self.add([x, self.seq(x)]) x = self.layer_norm(x) return x class EncoderLayer(tf.keras.layers.Layer): def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1): super().__init__() self.self_attention = GlobalSelfAttention( num_heads=num_heads, key_dim=d_model, dropout=dropout_rate) self.ffn = FeedForward(d_model, dff) def call(self, x): x = self.self_attention(x) x = self.ffn(x) return x class DecoderLayer(tf.keras.layers.Layer): def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1): super(DecoderLayer, self).__init__() self.causal_self_attention = CausalSelfAttention( num_heads=num_heads, key_dim=d_model, dropout=dropout_rate) self.cross_attention = CrossAttention( num_heads=num_heads, key_dim=d_model, dropout=dropout_rate) self.ffn = FeedForward(d_model, dff) def call(self, x, context): x = self.causal_self_attention(x=x) x = self.cross_attention(x=x, context=context) # Cache the last attention scores for plotting later self.last_attn_scores = self.cross_attention.last_attn_scores x = self.ffn(x) # Shape `(batch_size, seq_len, d_model)`. return x class Encoder(tf.keras.layers.Layer): def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1): super().__init__() self.d_model = d_model self.num_layers = num_layers self.pos_embedding = PositionalEmbedding( vocab_size=vocab_size, d_model=d_model) self.enc_layers = [ EncoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate) for _ in range(num_layers)] self.dropout = tf.keras.layers.Dropout(dropout_rate) def call(self, x): # `x` is token-IDs shape: (batch, seq_len) x = self.pos_embedding(x) # Shape `(batch_size, seq_len, d_model)`. # Add dropout. x = self.dropout(x) for i in range(self.num_layers): x = self.enc_layers[i](x) return x # Shape `(batch_size, seq_len, d_model)`. class Decoder(tf.keras.layers.Layer): def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1): super(Decoder, self).__init__() self.d_model = d_model self.num_layers = num_layers self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size, d_model=d_model) self.dropout = tf.keras.layers.Dropout(dropout_rate) self.dec_layers = [ DecoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate) for _ in range(num_layers)] self.last_attn_scores = None def call(self, x, context): # `x` is token-IDs shape (batch, target_seq_len) x = self.pos_embedding(x) # (batch_size, target_seq_len, d_model) x = self.dropout(x) for i in range(self.num_layers): x = self.dec_layers[i](x, context) self.last_attn_scores = self.dec_layers[-1].last_attn_scores # The shape of x is (batch_size, target_seq_len, d_model). return x # @tf.keras.saving.register_keras_serializable() class Transformer(tf.keras.Model): def __init__(self, *, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, dropout_rate=0.1): super().__init__() self.encoder = Encoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff, vocab_size=input_vocab_size, dropout_rate=dropout_rate) self.decoder = Decoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff, vocab_size=target_vocab_size, dropout_rate=dropout_rate) self.final_layer = tf.keras.layers.Dense(target_vocab_size) def call(self, inputs): # To use a Keras model with `.fit` you must pass all your inputs in the # first argument. context, x = inputs context = self.encoder(context) # (batch_size, context_len, d_model) x = self.decoder(x, context) # (batch_size, target_len, d_model) # Final linear layer output. logits = self.final_layer(x) # (batch_size, target_len, target_vocab_size) try: # Drop the keras mask, so it doesn't scale the losses/metrics. # b/250038731 del logits._keras_mask except AttributeError: pass # Return the final output and the attention weights. return logits # @tf.keras.saving.register_keras_serializable() # class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): # def __init__(self, d_model, warmup_steps=4000): # super().__init__() # self.d_model = d_model # self.d_model = tf.cast(self.d_model, tf.float32) # self.warmup_steps = warmup_steps # def __call__(self, step): # step = tf.cast(step, dtype=tf.float32) # arg1 = tf.math.rsqrt(step) # arg2 = step * (self.warmup_steps ** -1.5) # return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) # def get_config(self): # return { # 'd_model': int(self.d_model), # 'warmup_steps': int(self.warmup_steps) # } # # learning_rate = CustomSchedule(EMBEDDING_DEPTH) # # @tf.keras.saving.register_keras_serializable() # class CustomAdam(tf.keras.optimizers.Adam): # def __init__(self, custom_param, **kwargs): # super(CustomAdam, self).__init__(**kwargs) # self.custom_param = custom_param #this is the learning rate (custom schedule) # def get_config(self): # config = super(CustomAdam, self).get_config() # config.update({ # 'custom_param': self.custom_param # }) # return config # # optimizer = CustomAdam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9) # # @tf.keras.saving.register_keras_serializable() # def masked_loss(label, pred): # mask = label != 0 # loss_object = tf.keras.losses.SparseCategoricalCrossentropy( # from_logits=True, reduction='none') # loss = loss_object(label, pred) # mask = tf.cast(mask, dtype=loss.dtype) # loss *= mask # loss = tf.reduce_sum(loss)/tf.reduce_sum(mask) # return loss # # @tf.keras.saving.register_keras_serializable() # def masked_accuracy(label, pred): # pred = tf.argmax(pred, axis=2) # label = tf.cast(label, pred.dtype) # match = label == pred # mask = label != 0 # match = match & mask # match = tf.cast(match, dtype=tf.float32) # mask = tf.cast(mask, dtype=tf.float32) # return tf.reduce_sum(match)/tf.reduce_sum(mask)