# Ragdoll/models/embeddings.py
import keras
import numpy as np


class TokenAndPositionEmbedding(keras.layers.Layer):
    """Sums token embeddings with learned positional embeddings."""

    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        # Maps each token ID to an `embed_dim`-dimensional vector.
        self.token_emb = keras.layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        # Learned embedding for each position index up to `maxlen`.
        self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Sequence length of the current batch of token IDs.
        maxlen = keras.ops.shape(x)[-1]
        # Embed position indices [0, 1, ..., maxlen - 1]; the result has
        # shape (maxlen, embed_dim) and broadcasts over the batch axis
        # when added to the (batch, maxlen, embed_dim) token embeddings.
        positions = keras.ops.arange(0, maxlen, 1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
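

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the shapes
    # and hyperparameter values below are illustrative assumptions only.
    maxlen, vocab_size, embed_dim = 128, 10_000, 64
    layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)

    # A batch of 2 sequences, each `maxlen` token IDs long.
    token_ids = np.random.randint(0, vocab_size, size=(2, maxlen))
    out = layer(token_ids)
    print(out.shape)  # Expected: (2, 128, 64)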