File size: 4,009 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

# Per-sample shape of the encoder input (batch axis excluded). The encoder
# wraps Conv2D in TimeDistributed, so the first axis is the per-step (frame)
# axis and the remaining three are one frame's (height, width, channels),
# assuming Keras' default channels-last layout.
# NOTE: the decoder starts from a fixed 16x16 grid and upsamples twice with
# stride 2, so it always emits 64x64 frames; changing the spatial size here
# requires changing the decoder as well.
input_shape = (20, 64, 64, 1)

class Sampling(keras.layers.Layer):
    """Reparameterization-trick layer: draw ``z ~ N(z_mean, exp(z_log_var))``.

    The layer takes the pair ``(z_mean, z_log_var)`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is standard
    normal noise with the same shape as ``z_mean``. A fresh ``epsilon`` is
    drawn on every call, regardless of training or inference mode.
    """

    def call(self, inputs):
        """Sample a latent tensor from the given mean and log-variance.

        Args:
            inputs: A two-element sequence ``(z_mean, z_log_var)`` of tensors
                with identical shapes.

        Returns:
            A tensor with the same shape and dtype as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Use the dynamic shape so that unknown (None) static dimensions are
        # handled, and match z_mean's dtype so the multiply below does not
        # mix precisions (e.g. under a mixed-precision policy).
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        """Return the shape of ``z_mean`` (the first input); sampling keeps it."""
        return input_shape[0]


class VAE(keras.Model):
    """Per-frame convolutional variational autoencoder for frame sequences.

    Every frame of the input sequence is encoded independently (all layers
    are wrapped in ``TimeDistributed``) into a mean / log-variance pair, a
    latent vector is sampled from them, and the decoder maps each latent
    vector back to a 64x64 single-channel frame.

    Args:
        latent_dim: Number of channels of the 1x1 bottleneck convolution that
            closes the convolutional part of the encoder. Despite the name,
            this is *not* the size of the sampled latent vector.
        num_embeddings: Size of the per-frame latent vector, i.e. the number
            of units of the ``mu`` and ``logvar`` dense heads.
        beta: Weight of the KL term in ``total = reconstruction + beta * kl``.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, latent_dim:int=32, num_embeddings:int=128, beta:float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        # The decoder's input shape is read from the encoder's output, so the
        # encoder must be built first.
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()

        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Trackers Keras should reset at the start of every epoch.

        Without this property the ``Mean`` trackers updated in ``train_step``
        are never reset, so the logged losses would be running averages over
        the whole training run instead of per-epoch averages.
        """
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def get_encoder(self):
        """Build the encoder model.

        Returns:
            A ``keras.Model`` mapping a batch shaped ``(batch, *input_shape)``
            to ``[mu, logvar, z]``, each holding one ``num_embeddings``-sized
            vector per frame.
        """
        encoder_inputs = keras.Input(shape=input_shape)
        # Two stride-2 convolutions downsample each frame by a factor of 4.
        x = layers.TimeDistributed(layers.Conv2D(32, 3, activation="relu", strides=2, padding="same"))(
            encoder_inputs
        )
        x = layers.TimeDistributed(layers.Conv2D(64, 3, activation="relu", strides=2, padding="same"))(x)
        # 1x1 convolution that only changes the channel count.
        x = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)

        x = layers.TimeDistributed(layers.Flatten())(x)
        mu = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        logvar = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        z = Sampling()([mu, logvar])

        return keras.Model(encoder_inputs, [mu, logvar, z], name="encoder")

    def get_decoder(self):
        """Build the decoder model; requires ``self.encoder`` to exist.

        Returns:
            A ``keras.Model`` mapping per-frame latent vectors (same shape as
            the encoder's ``z`` output) to per-frame 64x64x1 reconstructions
            with values in ``(0, 1)``.
        """
        latent_inputs = keras.Input(shape=self.encoder.output[2].shape[1:])

        x = layers.TimeDistributed(layers.Dense(16 * 16 * 32, activation="relu"))(latent_inputs)
        x = layers.TimeDistributed(layers.Reshape((16, 16, 32)))(x)
        # Two stride-2 transposed convolutions: 16x16 -> 32x32 -> 64x64.
        x = layers.TimeDistributed(layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same"))(
            x
        )
        x = layers.TimeDistributed(layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same"))(x)
        # Sigmoid keeps the output in (0, 1): train_step scores it with
        # binary cross-entropy in its default from_logits=False mode, which
        # expects probabilities. A linear output would be clipped by the loss
        # and receive no gradient wherever it falls outside that range.
        decoder_outputs = layers.TimeDistributed(
            layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")
        )(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: Either ``(x, y)`` with ``y`` the reconstruction target, or
                just ``x`` / ``(x,)``, in which case the input itself is the
                target. Targets are presumably scaled to ``[0, 1]`` as binary
                cross-entropy expects -- confirm against the data pipeline.

        Returns:
            Dict with the running means of ``loss``, ``reconstruction_loss``
            and ``kl_loss``.
        """
        if isinstance(data, (tuple, list)):
            x = data[0]
            # fit(x) without a target yields a 1-tuple (or a None target).
            y = data[1] if len(data) > 1 and data[1] is not None else x
        else:
            x = y = data

        with tf.GradientTape() as tape:
            mu, logvar, z = self.encoder(x, training=True)
            reconstruction = self.decoder(z, training=True)

            # binary_crossentropy averages over the channel axis, leaving
            # (batch, frames, height, width). Sum over every non-batch axis so
            # each sample's loss covers all its pixels, then average the batch.
            per_pixel = tf.keras.losses.binary_crossentropy(y, reconstruction)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(per_pixel, axis=(1, 2, 3))
            )

            # KL(N(mu, exp(logvar)) || N(0, 1)) per latent unit, shaped
            # (batch, frames, num_embeddings); summed over frames and units.
            kl_terms = -0.5 * (1 + logvar - tf.square(mu) - tf.exp(logvar))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_terms, axis=(1, 2)))

            total_loss = reconstruction_loss + self.beta * kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, inputs, training=False, mask=None):
        """Encode ``inputs``, sample a latent, and return its reconstruction.

        The latent is sampled (not the mean), so repeated calls on the same
        input give different outputs. ``mask`` is accepted for Keras API
        compatibility and is not used.
        """
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)