video-transformers / utils /custom_layers.py
shivi's picture
Adding all app files with examples
3d12539
import tensorflow as tf
from tensorflow import keras
from keras import layers
class PositionalEmbedding(layers.Layer):
    """Add a per-position embedding vector to every step of a sequence.

    Args:
        sequence_length: Number of distinct positions; used as `input_dim`
            of the underlying `Embedding` table.
        output_dim: Width of each position vector; used as `output_dim` of
            the underlying `Embedding` table.
        **kwargs: Forwarded unchanged to the base `Layer` constructor.
    """

    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept as attributes so `get_config` can report them.
        self.sequence_length = sequence_length
        self.output_dim = output_dim
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length,
            output_dim=output_dim,
        )

    def call(self, inputs):
        """Return `inputs` plus the embedding of each position index."""
        # The inputs are of shape: `(batch_size, frames, num_features)`,
        # so axis 1 gives the number of positions to look up.
        num_frames = tf.shape(inputs)[1]
        frame_indices = tf.range(start=0, limit=num_frames, delta=1)
        # The looked-up vectors have no batch axis; the addition broadcasts
        # them across the batch.
        return inputs + self.position_embeddings(frame_indices)

    def compute_mask(self, inputs, mask=None):
        """Flag a step as valid when any of its features is non-zero.

        The incoming `mask` argument is not consulted; the result is derived
        from `inputs` alone.
        """
        is_nonzero = tf.cast(inputs, "bool")
        return tf.reduce_any(is_nonzero, axis=-1)

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        base_config = super().get_config()
        return {
            **base_config,
            "sequence_length": self.sequence_length,
            "output_dim": self.output_dim,
        }
class TransformerEncoder(layers.Layer):
    """Transformer encoder block: self-attention followed by a two-layer MLP.

    Each sub-block is wrapped in a residual connection and followed by layer
    normalization.

    Args:
        embed_dim: Output width of the block; also passed as `key_dim` to
            the attention layer and used as the size of the final `Dense`.
        dense_dim: Hidden width of the feed-forward projection.
        num_heads: Number of attention heads.
        dropout: Dropout rate handed to `MultiHeadAttention`. Defaults to
            0.3, the value that was previously hard-coded, so existing
            callers and previously saved configs behave as before.
        **kwargs: Forwarded unchanged to the base `Layer` constructor.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, dropout=0.3, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.dropout = dropout
        # Sub-layers are created in a fixed order; changing it would change
        # auto-generated layer names and the order of the layer's weights.
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=dropout
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation=tf.nn.gelu),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Apply self-attention and the feed-forward projection to `inputs`.

        Args:
            inputs: Sequence tensor used as both query and value of the
                attention layer.
            mask: Optional rank-2 mask (presumably `(batch, frames)`, as
                produced by an upstream `compute_mask` -- confirm against
                the caller). `None` disables explicit masking.

        Returns:
            The layer-normalized output of the second residual connection.
        """
        if mask is not None:
            # Insert a singleton axis so the mask broadcasts over the query
            # positions of the attention scores.
            mask = mask[:, tf.newaxis, :]
        attention_output = self.attention(inputs, inputs, attention_mask=mask)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
                "dropout": self.dropout,
            }
        )
        return config