import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Learned embedding for each frame position in the sequence.
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )
        self.sequence_length = sequence_length
        self.output_dim = output_dim

    def call(self, inputs):
        # The inputs are of shape: `(batch_size, frames, num_features)`
        length = tf.shape(inputs)[1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_positions = self.position_embeddings(positions)
        return inputs + embedded_positions

    def compute_mask(self, inputs, mask=None):
        # Mark all-zero (padded) frames so downstream layers can ignore them.
        mask = tf.reduce_any(tf.cast(inputs, "bool"), axis=-1)
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "output_dim": self.output_dim,
        })
        return config


class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.3
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation=tf.nn.gelu),
             layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        if mask is not None:
            # Broadcast the padding mask over the query dimension
            # so it can be used as an attention mask.
            mask = mask[:, tf.newaxis, :]
        attention_output = self.attention(inputs, inputs, attention_mask=mask)
        # Residual connection + layer norm around the attention block,
        # then the same pattern around the feed-forward projection.
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config
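
To show how these two layers fit together, here is a minimal sketch of a sequence classifier built from them. The hyperparameters (sequence_length, embed_dim, dense_dim, num_heads, num_classes) are placeholder values chosen for illustration, not values prescribed by the code above; the pooling, dropout, and compile settings are likewise one reasonable choice among many.

# Placeholder hyperparameters for illustration only.
sequence_length = 20
embed_dim = 1024
dense_dim = 4
num_heads = 1
num_classes = 10

# Inputs: a batch of feature sequences of shape (frames, embed_dim).
inputs = keras.Input(shape=(None, embed_dim))
x = PositionalEmbedding(sequence_length, embed_dim)(inputs)
x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)
# Collapse the time dimension, then classify.
x = layers.GlobalMaxPooling1D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

Note that the mask produced by PositionalEmbedding.compute_mask propagates automatically to TransformerEncoder, which is why its call method accepts an optional mask argument rather than requiring one to be passed explicitly.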