import tensorflow as tf
from tensorflow import keras
from transformer import TransformerDecoder
import tensorflow_probability as tfp


class TFIng(keras.Model):
    """Image-to-token-sequence model: InceptionV3 encoder + TransformerDecoder.

    Image features from an ImageNet-initialised InceptionV3 backbone (no
    classification head) are projected with a 1x1 convolution and flattened
    into a sequence. That sequence is passed, together with the target
    tokens, to a ``TransformerDecoder``. A final Dense layer produces one
    logit per vocabulary entry. Tokens that already occurred in the target
    prefix are masked to ``-inf`` (see ``get_replacement_mask``).
    """

    def __init__(self, crop_size, embed_dim, num_layers, seq_length,
                 hidden_dim, num_heads, target_vocab_size, dropout_rate=0.1):
        """Build the encoder, projection, decoder and output head.

        Args:
            crop_size: ``(height, width)`` of the input images. Any
                two-element sequence (tuple or list) is accepted.
            embed_dim: number of output channels of the 1x1 projection.
                Also forwarded to ``TransformerDecoder``.
            num_layers, seq_length, hidden_dim, num_heads: forwarded
                unchanged to ``TransformerDecoder`` (their meaning is defined
                there).
            target_vocab_size: size of the output vocabulary, i.e. the width
                of the final Dense layer and of the replacement mask.
            dropout_rate: forwarded to ``TransformerDecoder``.
        """
        super().__init__()
        self.target_vocab_size = target_vocab_size
        self.encoder = keras.applications.InceptionV3(
            include_top=False,
            weights="imagenet",
            # tuple() lets callers pass a list; Keras wants (h, w, channels).
            input_shape=tuple(crop_size) + (3,),
        )
        # 1x1 conv: change the channel count to embed_dim, keep the spatial grid.
        self.conv = keras.layers.Conv2D(embed_dim, 1)
        self.decoder = TransformerDecoder(
            num_layers,
            seq_length,
            embed_dim,
            hidden_dim,
            num_heads,
            target_vocab_size,
            dropout_rate=dropout_rate,
        )
        self.linear = keras.layers.Dense(target_vocab_size, activation=None)

    def call(self, inputs, training=False):
        """Compute masked vocabulary logits for every target position.

        Args:
            inputs: pair ``(images, targets)``. ``targets`` is a rank-2
                integer tensor of token ids, ``(batch, seq_len)``.
            training: forwarded to the encoder, projection and decoder.

        Returns:
            Logits from the final Dense layer plus the additive replacement
            mask. Presumably ``(batch, seq_len, target_vocab_size)``; this
            depends on the decoder emitting one vector per target position.
            Entries for token ids already seen in the target prefix are
            ``-inf``.
        """
        encoder_inputs, targets = inputs
        encoder_out = self.encoder(encoder_inputs, training=training)
        encoder_out = self.conv(encoder_out, training=training)
        # Flatten the spatial grid: (batch, H, W, C) -> (batch, H*W, C).
        encoder_out = tf.reshape(
            encoder_out,
            (tf.shape(encoder_out)[0], -1, tf.shape(encoder_out)[3]),
        )
        decoder_outputs = self.decoder(targets, encoder_out, training=training)
        logits = self.linear(decoder_outputs)
        # Match the logits' dtype so the sum also works under mixed precision.
        # -inf is representable in every floating dtype.
        mask = tf.cast(self.get_replacement_mask(targets), logits.dtype)
        return logits + mask

    def get_replacement_mask(self, targets):
        """Build an additive mask that forbids re-emitting already-seen tokens.

        For batch row ``b``, position ``i`` and vocabulary id ``v``, the
        result is ``-inf`` if ``targets[b, j] == v`` for some ``j <= i`` and
        ``v != 0``. Every other entry is ``0``. Token id 0 (presumably
        padding) is never masked. Ids outside ``[0, target_vocab_size)``
        yield an all-zero one-hot row and are therefore ignored. Repeated ids
        within a sequence are fine.

        Args:
            targets: integer tensor ``(batch, seq_len)`` of token ids.

        Returns:
            float32 tensor ``(batch, seq_len, target_vocab_size)``.
        """
        targets = tf.cast(targets, tf.int32)
        # (batch, seq_len, vocab): 1 where position j holds token v.
        hits = tf.one_hot(targets, self.target_vocab_size, dtype=tf.int32)
        # Drop id 0 so it is never masked.
        hits = hits * tf.cast(targets != 0, tf.int32)[..., tf.newaxis]
        # Running count along the sequence axis. A count > 0 at (b, i, v)
        # means v occurred at some j <= i. Every shape stays symbolic, so
        # this traces under tf.function.
        seen = tf.math.cumsum(hits, axis=1) > 0
        neg_inf = tf.fill(tf.shape(seen), float("-inf"))
        return tf.where(seen, neg_inf, tf.zeros_like(neg_inf))