# File size: 481 bytes (revisions 301ca84, e2692aa).
import tensorflow as tf
class _TransformerBlockBase(tf.keras.layers.Layer):
    """Hyper-parameters and sub-layer factories shared by encoder/decoder blocks.

    Keyword arguments follow ``keras_nlp.layers.TransformerEncoder`` /
    ``TransformerDecoder`` so the ``config["*_params"]`` dicts have a
    documented schema. The hidden size is not a parameter: it is read from
    the last axis of the input in ``build``.

    Args:
        intermediate_dim: width of the position-wise feed-forward layer.
        num_heads: number of attention heads; must divide the hidden size.
        dropout: dropout rate for attention weights and residual branches.
        activation: activation of the first feed-forward ``Dense`` layer.
        layer_norm_epsilon: epsilon of every ``LayerNormalization``.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(
        self,
        intermediate_dim,
        num_heads,
        dropout=0.0,
        activation="relu",
        layer_norm_epsilon=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.activation = activation
        self.layer_norm_epsilon = layer_norm_epsilon
        # Dimension-free, so it can be created eagerly; reused for every residual branch.
        self._dropout = tf.keras.layers.Dropout(dropout)

    def _checked_hidden_dim(self, shape):
        """Return the static last-axis size of ``shape``.

        Raises:
            ValueError: if that size is unknown or not divisible by ``num_heads``.
        """
        hidden_dim = shape[-1]
        if hidden_dim is None:
            raise ValueError("The last (feature) axis of the input must have a static size.")
        hidden_dim = int(hidden_dim)
        if hidden_dim % self.num_heads:
            raise ValueError(
                f"Feature size {hidden_dim} is not divisible by num_heads={self.num_heads}."
            )
        return hidden_dim

    def _make_attention(self, hidden_dim):
        """Multi-head attention whose heads together span ``hidden_dim`` features."""
        return tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=hidden_dim // self.num_heads,
            dropout=self.dropout_rate,
        )

    def _make_norm(self):
        return tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_epsilon)

    def _make_feed_forward(self, hidden_dim):
        """Two-layer MLP applied to the last axis: hidden -> intermediate -> hidden."""
        return tf.keras.Sequential(
            [
                tf.keras.layers.Dense(self.intermediate_dim, activation=self.activation),
                tf.keras.layers.Dense(hidden_dim),
            ]
        )

    def _add_and_norm(self, norm, residual, branch, training):
        """Post-norm residual connection: ``norm(residual + dropout(branch))``."""
        return norm(residual + self._dropout(branch, training=training))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "dropout": self.dropout_rate,
                "activation": self.activation,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config


class _TransformerEncoderLayer(_TransformerBlockBase):
    """One encoder block: self-attention then feed-forward, each post-normed."""

    def build(self, inputs_shape):
        hidden_dim = self._checked_hidden_dim(inputs_shape)
        self._self_attention = self._make_attention(hidden_dim)
        self._self_attention_norm = self._make_norm()
        self._feed_forward = self._make_feed_forward(hidden_dim)
        self._feed_forward_norm = self._make_norm()
        self.built = True

    def call(self, inputs, training=None):
        attended = self._self_attention(inputs, inputs, training=training)
        x = self._add_and_norm(self._self_attention_norm, inputs, attended, training)
        transformed = self._feed_forward(x, training=training)
        return self._add_and_norm(self._feed_forward_norm, x, transformed, training)


class _TransformerDecoderLayer(_TransformerBlockBase):
    """One decoder block: causal self-attention, cross-attention, feed-forward."""

    # Keras 2 passes only the first call argument's shape positionally;
    # Keras 3 passes ``<arg>_shape`` keywords, hence the optional second arg.
    def build(self, decoder_sequence_shape, encoder_sequence_shape=None):
        hidden_dim = self._checked_hidden_dim(decoder_sequence_shape)
        self._self_attention = self._make_attention(hidden_dim)
        self._self_attention_norm = self._make_norm()
        self._cross_attention = self._make_attention(hidden_dim)
        self._cross_attention_norm = self._make_norm()
        self._feed_forward = self._make_feed_forward(hidden_dim)
        self._feed_forward_norm = self._make_norm()
        self.built = True

    def call(self, decoder_sequence, encoder_sequence, training=None):
        # Causal mask: position i may only attend to target positions <= i.
        attended = self._self_attention(
            decoder_sequence, decoder_sequence, use_causal_mask=True, training=training
        )
        x = self._add_and_norm(self._self_attention_norm, decoder_sequence, attended, training)
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        attended = self._cross_attention(x, encoder_sequence, training=training)
        x = self._add_and_norm(self._cross_attention_norm, x, attended, training)
        transformed = self._feed_forward(x, training=training)
        return self._add_and_norm(self._feed_forward_norm, x, transformed, training)


class TransformerModel(tf.keras.Model):
    """Encoder-decoder Transformer built from stock Keras layers.

    ``tf.keras.layers`` has no ``Transformer`` layer, so this model composes
    one encoder block and one decoder block from ``MultiHeadAttention``,
    ``LayerNormalization`` and ``Dense`` instead. No embedding or output
    projection is applied: inputs are assumed to be already-embedded float
    tensors of shape ``(batch, seq_len, hidden_dim)`` — confirm with callers.

    Args:
        config: mapping with ``"encoder_params"`` and ``"decoder_params"``
            dicts, each unpacked as keyword arguments: ``intermediate_dim`` and
            ``num_heads`` (required), plus optional ``dropout``,
            ``activation`` and ``layer_norm_epsilon``.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.encoder = _TransformerEncoderLayer(**config["encoder_params"])
        self.decoder = _TransformerDecoderLayer(**config["decoder_params"])

    def call(self, inputs, targets=None, training=None):
        """Encode ``inputs`` and decode ``targets`` against the encoding.

        Args:
            inputs: source sequence, or an ``(inputs, targets)`` pair when
                ``targets`` is omitted (the form ``fit``/``predict`` pass as ``x``).
            targets: target sequence fed to the decoder.
            training: forwarded to the dropout layers.

        Returns:
            The decoder output, with the same shape as ``targets``.

        Raises:
            TypeError: if ``targets`` is omitted and ``inputs`` is not a pair.
        """
        if targets is None:
            # Check explicitly: a bare eager tensor with batch size 2 would
            # otherwise be unpacked silently along its batch axis.
            if not isinstance(inputs, (tuple, list)) or len(inputs) != 2:
                raise TypeError(
                    "Pass `targets` explicitly or call the model with an (inputs, targets) pair."
                )
            inputs, targets = inputs
        encoder_output = self.encoder(inputs, training=training)
        return self.decoder(targets, encoder_output, training=training)