import tensorflow as tf
from tensorflow import keras
from transformer import TransformerDecoder
import tensorflow_probability as tfp
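
# Illustration, not part of the original model: TFIng.call flattens the 1x1-conv
# feature map of shape (batch, height, width, embed_dim) into a sequence of
# height * width positions for the decoder. tf.reshape reads values in row-major
# order, so grid cell (h, w) becomes sequence position h * width + w. For example,
# a 299x299 crop gives an 8x8 InceptionV3 map, i.e. 64 positions. The helper
# below (flatten_feature_map, a hypothetical name) mirrors that ordering in pure
# Python for a single example stored as nested [height][width][channels] lists.
def flatten_feature_map(feature_map):
    """Flatten [height][width][channels] nested lists into [height * width][channels]."""
    # Row-major traversal: every cell of row 0, then row 1, and so on.
    return [cell for row in feature_map for cell in row]


# Usage: in a 2x3 grid, cell (1, 1) lands at position 1 * 3 + 1 = 4, so
# flatten_feature_map([[[0], [1], [2]], [[3], [4], [5]]])[4] == [4].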

class TFIng(keras.Model):
    def __init__(self, crop_size, embed_dim, num_layers, seq_length, hidden_dim, num_heads,
                 target_vocab_size, dropout_rate=0.1):
        super().__init__()
        self.target_vocab_size = target_vocab_size

        # Image encoder: ImageNet-pretrained InceptionV3 without its classification
        # head. crop_size must be a (height, width) tuple.
        self.encoder = keras.applications.InceptionV3(
            include_top=False,
            weights="imagenet",
            input_shape=crop_size + (3,),
        )
        # 1x1 convolution projecting the encoder's channels down to embed_dim.
        self.conv = keras.layers.Conv2D(embed_dim, 1)
        self.decoder = TransformerDecoder(num_layers, seq_length, embed_dim, hidden_dim, num_heads,
                                          target_vocab_size, dropout_rate=dropout_rate)
        # Projects decoder outputs to unnormalized logits over the target vocabulary.
        self.linear = keras.layers.Dense(target_vocab_size, activation=None)

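    def call(self, inputs, training=False):
        # inputs is an (images, targets) pair; targets holds the decoder's input token ids.
        encoder_inputs, targets = inputs
        # (batch, H, W, 2048) InceptionV3 feature map -> (batch, H, W, embed_dim).
        encoder_out = self.encoder(encoder_inputs, training=training)
        encoder_out = self.conv(encoder_out)
        # Flatten the spatial grid into a sequence: (batch, H * W, embed_dim).
        encoder_out = tf.reshape(encoder_out, (tf.shape(encoder_out)[0], -1, tf.shape(encoder_out)[3]))
        decoder_outputs = self.decoder(targets, encoder_out, training=training)
        output = self.linear(decoder_outputs)
        # Add -inf to the logits of tokens already present in the target prefix.
        return output + self.get_replacement_mask(targets)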

    def get_replacement_mask(self, targets):
        # Additive mask of shape (batch, seq_length, target_vocab_size): at step i,
        # every non-padding token already present in targets[:, :i + 1] gets -inf,
        # so the model cannot predict the same token again.
        targets = tf.cast(targets, tf.int32)
        batch_size, seq_length = tf.shape(targets)[0], tf.shape(targets)[1]

        # Lower-triangular (seq_length, seq_length) matrix of ones: row i keeps
        # columns 0..i. Integer tensor arithmetic (instead of int()) keeps this
        # working inside tf.function, where seq_length is a symbolic tensor.
        n = seq_length * (seq_length + 1) // 2
        mask = tfp.math.fill_triangular(tf.ones([n], dtype=tf.int32))
        mask = tf.repeat(mask[tf.newaxis, :], batch_size, axis=0)

        # targets_masked[b, i, j] = targets[b, j] if j <= i, else 0.
        targets_repeated = tf.repeat(targets[:, tf.newaxis, :], seq_length, axis=1)
        targets_masked = targets_repeated * mask
        # Token id 0 is treated as padding and is never masked.
        keep = targets_masked != 0

        # Vocabulary column (token id) of every entry to mask.
        columns = tf.boolean_mask(targets_masked, keep)

        # Decoding-step row of every entry: rows_idx_repeated[b, i, j] = i.
        rows_idx = tf.range(seq_length)
        rows_idx_repeated = tf.reshape(tf.repeat(rows_idx, seq_length), (seq_length, seq_length))
        rows_idx_repeated = tf.repeat(rows_idx_repeated[tf.newaxis, :], batch_size, axis=0)
        rows = tf.boolean_mask(rows_idx_repeated, keep)

        # Batch index of every entry: batches_idx_repeated[b, i, j] = b.
        batches_idx = tf.range(batch_size)
        batches_idx_repeated = tf.reshape(
            tf.repeat(batches_idx, seq_length * seq_length), (batch_size, seq_length, seq_length)
        )
        batches = tf.boolean_mask(batches_idx_repeated, keep)

        idx = tf.stack([batches, rows, columns], axis=1)

        # Write -inf at every (batch, step, token) index. Unlike tf.sparse.to_dense,
        # tf.tensor_scatter_nd_update does not reject repeated indices (e.g. a token
        # that occurs twice in targets); all updates are -inf, so order is irrelevant.
        return tf.tensor_scatter_nd_update(
            tf.zeros([batch_size, seq_length, self.target_vocab_size]),
            idx,
            tf.fill([tf.shape(idx)[0]], float("-inf")),
        )
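

# Reference sketch, not part of the original model: a pure-Python version of
# TFIng.get_replacement_mask for checking its semantics on small inputs. Like the
# TensorFlow code, it assumes token id 0 is padding. At step i, every non-padding
# token in targets[b][:i + 1] receives -inf, so adding the mask to the logits rules
# out re-predicting a token already seen. The name reference_replacement_mask is
# hypothetical.
def reference_replacement_mask(targets, vocab_size):
    """Return a nested [batch][step][vocab] list of 0.0 / -inf values."""
    mask = []
    for sequence in targets:
        seen = set()  # non-padding tokens observed up to and including this step
        rows = []
        for token in sequence:
            if token != 0:
                seen.add(token)
            rows.append([float("-inf") if v in seen else 0.0 for v in range(vocab_size)])
        mask.append(rows)
    return mask


# Usage: reference_replacement_mask([[2, 1, 0]], 4)[0] gives
# step 0: [0.0, 0.0, -inf, 0.0]
# step 1: [0.0, -inf, -inf, 0.0]
# step 2: same as step 1 (the padding token 0 adds nothing).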