File size: 481 Bytes
301ca84
 
 
 
 
e2692aa
 
301ca84
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import tensorflow as tf

class _FeedForward(tf.keras.layers.Layer):
    """Position-wise feed-forward sub-layer with residual connection and layer norm.

    Computes ``norm(x + dropout(proj(relu(hidden(x)))))``. The output
    projection width is taken from the input's last dimension at build time,
    so the residual addition always has matching shapes.
    """

    def __init__(self, ff_dim, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden = tf.keras.layers.Dense(ff_dim, activation="relu")
        self.output_dropout = tf.keras.layers.Dropout(dropout)
        self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        # Width of the residual stream is only known once the input shape is.
        self.projection = tf.keras.layers.Dense(input_shape[-1])
        super().build(input_shape)

    def call(self, x, training=None):
        y = self.projection(self.hidden(x))
        return self.norm(x + self.output_dropout(y, training=training))


class _EncoderBlock(tf.keras.layers.Layer):
    """Post-norm Transformer encoder block: self-attention, then feed-forward."""

    def __init__(self, num_heads, key_dim, ff_dim, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout
        )
        self.attention_dropout = tf.keras.layers.Dropout(dropout)
        self.attention_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.feed_forward = _FeedForward(ff_dim, dropout)

    def call(self, x, training=None):
        # MultiHeadAttention projects back to the query's last dimension by
        # default, so the residual sum below is shape-consistent.
        attended = self.self_attention(x, x, training=training)
        x = self.attention_norm(x + self.attention_dropout(attended, training=training))
        return self.feed_forward(x, training=training)


class _DecoderBlock(tf.keras.layers.Layer):
    """Post-norm Transformer decoder block.

    Applies causal self-attention, then cross-attention over the encoder
    output, then feed-forward.
    """

    def __init__(self, num_heads, key_dim, ff_dim, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout
        )
        self.self_attention_dropout = tf.keras.layers.Dropout(dropout)
        self.self_attention_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.cross_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout
        )
        self.cross_attention_dropout = tf.keras.layers.Dropout(dropout)
        self.cross_attention_norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.feed_forward = _FeedForward(ff_dim, dropout)

    def call(self, x, context, training=None):
        # Causal mask: target position i attends only to positions <= i, so
        # the decoder cannot see future target tokens. `use_causal_mask`
        # requires TF >= 2.10.
        attended = self.self_attention(x, x, use_causal_mask=True, training=training)
        x = self.self_attention_norm(
            x + self.self_attention_dropout(attended, training=training)
        )
        # Queries come from the decoder stream, keys/values from the encoder.
        attended = self.cross_attention(x, context, training=training)
        x = self.cross_attention_norm(
            x + self.cross_attention_dropout(attended, training=training)
        )
        return self.feed_forward(x, training=training)


class _Encoder(tf.keras.layers.Layer):
    """Stack of ``num_layers`` encoder blocks sharing one hyper-parameter set."""

    def __init__(self, num_layers=1, **block_params):
        super().__init__()
        self.blocks = [_EncoderBlock(**block_params) for _ in range(num_layers)]

    def call(self, inputs, training=None):
        x = inputs
        for block in self.blocks:
            x = block(x, training=training)
        return x


class _Decoder(tf.keras.layers.Layer):
    """Stack of ``num_layers`` decoder blocks sharing one hyper-parameter set."""

    def __init__(self, num_layers=1, **block_params):
        super().__init__()
        self.blocks = [_DecoderBlock(**block_params) for _ in range(num_layers)]

    def call(self, targets, context, training=None):
        x = targets
        for block in self.blocks:
            x = block(x, context, training=training)
        return x


class TransformerModel(tf.keras.Model):
    """Encoder-decoder Transformer assembled from core Keras layers.

    TensorFlow/Keras ship no ready-made ``Transformer`` layer, so the encoder
    and decoder stacks are built here from ``MultiHeadAttention``,
    ``LayerNormalization``, ``Dense`` and ``Dropout``.

    ``config["encoder_params"]`` and ``config["decoder_params"]`` are dicts of
    keyword arguments for the encoder and decoder stacks respectively:

    - ``num_heads`` (int, required): attention heads per block.
    - ``key_dim`` (int, required): size of each attention head.
    - ``ff_dim`` (int, required): hidden width of the feed-forward sub-layer.
    - ``dropout`` (float, default 0.0): dropout rate inside each block.
    - ``num_layers`` (int, default 1): number of stacked blocks.

    NOTE(review): confirm that existing configs use these key names.

    Raises:
        KeyError: if ``config`` lacks ``"encoder_params"`` or ``"decoder_params"``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # The config dicts are unpacked, not mutated, so callers can reuse them.
        self.encoder = _Encoder(**config["encoder_params"])
        self.decoder = _Decoder(**config["decoder_params"])

    def call(self, inputs, targets, training=None):
        """Encode ``inputs`` and decode ``targets`` against the encoding.

        Assumes ``inputs`` and ``targets`` are already-embedded float tensors
        shaped (batch, seq_len, features); no token embedding or positional
        encoding is applied here — TODO confirm against callers.

        Args:
            inputs: Source sequence fed to the encoder.
            targets: Target sequence fed to the (causally masked) decoder.
            training: Forwarded to dropout layers; Keras sets it automatically
                in ``fit``/``evaluate``/``predict``.

        Returns:
            Decoder output with the same shape as ``targets``.
        """
        encoder_output = self.encoder(inputs, training=training)
        decoder_output = self.decoder(targets, encoder_output, training=training)
        return decoder_output