# Author: davidaf3
# Last commit bf8a52d: "Removed relative imports" (file size 3.21 kB)
import tensorflow as tf
from tensorflow import keras
from transformer import TransformerDecoder
import tensorflow_probability as tfp
class TFIng(keras.Model):
    """Image-conditioned token-sequence model: InceptionV3 encoder + Transformer decoder.

    The InceptionV3 feature map (no classification head) is projected to
    ``embed_dim`` channels by a 1x1 convolution, flattened over its spatial
    positions and passed to ``TransformerDecoder`` together with the target
    tokens. A dense layer maps decoder outputs to ``target_vocab_size`` logits.
    An additive mask (see ``get_replacement_mask``) sets the logits of tokens
    already present in the target prefix to ``-inf``, so a softmax assigns
    them zero probability.

    Args:
        crop_size: Spatial image size ``(height, width)``; any two-element
            sequence (tuple or list) is accepted.
        embed_dim: Channel size of the projected image features; also passed
            to the decoder.
        num_layers, seq_length, hidden_dim, num_heads: Forwarded unchanged to
            ``TransformerDecoder``.
        target_vocab_size: Size of the output vocabulary (last logits axis).
        dropout_rate: Forwarded to ``TransformerDecoder``.
    """

    def __init__(self, crop_size, embed_dim, num_layers, seq_length, hidden_dim,
                 num_heads, target_vocab_size, dropout_rate=0.1):
        super().__init__()
        self.target_vocab_size = target_vocab_size
        # Pretrained ImageNet backbone without its top, so it yields a spatial
        # feature map instead of class scores.
        self.encoder = keras.applications.InceptionV3(
            include_top=False,
            weights="imagenet",
            # tuple() so that a list crop_size works as well as a tuple.
            input_shape=tuple(crop_size) + (3,),
        )
        # 1x1 convolution: projects backbone channels to embed_dim.
        self.conv = keras.layers.Conv2D(embed_dim, 1)
        self.decoder = TransformerDecoder(
            num_layers, seq_length, embed_dim, hidden_dim, num_heads,
            target_vocab_size, dropout_rate=dropout_rate,
        )
        # Linear projection from decoder states to vocabulary logits.
        self.linear = keras.layers.Dense(target_vocab_size, activation=None)

    def call(self, inputs, training=False):
        """Return vocabulary logits with already-used tokens masked out.

        Args:
            inputs: Pair ``(images, targets)``. ``images`` is fed to the
                InceptionV3 encoder; ``targets`` holds integer token ids,
                presumably shape ``(batch, seq_length)`` with 0 as padding
                (TODO confirm against the data pipeline).
            training: Forwarded to the encoder, the projection and the decoder.

        Returns:
            Decoder logits plus the replacement mask. The mask has shape
            ``(batch, seq_length, target_vocab_size)``; the decoder output is
            assumed to match it.
        """
        encoder_inputs, targets = inputs
        encoder_out = self.encoder(encoder_inputs, training=training)
        encoder_out = self.conv(encoder_out, training=training)
        # Flatten the spatial grid into a sequence of feature vectors:
        # (batch, h, w, embed_dim) -> (batch, h * w, embed_dim).
        shape = tf.shape(encoder_out)
        encoder_out = tf.reshape(encoder_out, (shape[0], -1, shape[3]))
        decoder_outputs = self.decoder(targets, encoder_out, training=training)
        output = self.linear(decoder_outputs)
        # Build the mask in the logits' dtype so the addition also works under
        # mixed precision (e.g. float16 logits).
        return output + self.get_replacement_mask(targets, dtype=output.dtype)

    def get_replacement_mask(self, targets, dtype=tf.float32):
        """Return an additive mask forbidding tokens already seen in the prefix.

        For batch element ``b`` and position ``i``, the mask is ``-inf`` at
        vocabulary index ``v`` when ``v != 0`` and ``v`` occurs in
        ``targets[b, :i + 1]`` (the prefix up to and including position
        ``i``); it is ``0`` everywhere else. Token id 0 is never masked, and a
        token occurring several times is handled like a single occurrence.

        Args:
            targets: Integer tensor of token ids, shape ``(batch, seq_length)``.
                Ids that are negative or ``>= target_vocab_size`` produce no
                mask entries.
            dtype: Floating dtype of the returned mask. Defaults to
                ``tf.float32``, the dtype previously returned.

        Returns:
            Dense tensor of shape ``(batch, seq_length, target_vocab_size)``
            with values ``0`` and ``-inf``.
        """
        targets = tf.cast(targets, tf.int32)
        # one_hot(...)[b, j, v] == 1 iff targets[b, j] == v. A cumulative sum
        # along the sequence axis counts the occurrences of v in
        # targets[b, :i + 1].
        # Built densely (not from sparse indices) so that repeated tokens
        # cannot create duplicate indices, and so that no Python int() of a
        # dynamic shape is needed; this keeps the method usable in graph mode.
        occurrences = tf.cumsum(
            tf.one_hot(targets, self.target_vocab_size, dtype=tf.int32), axis=1
        )
        seen = occurrences > 0
        # Never mask id 0 (zero-valued targets are skipped; presumably padding).
        maskable_ids = tf.range(self.target_vocab_size) != 0
        seen = tf.logical_and(seen, maskable_ids)
        return tf.where(
            seen,
            tf.constant(float("-inf"), dtype=dtype),
            tf.constant(0, dtype=dtype),
        )