import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Each sample is a sequence of 20 grayscale 64x64 frames: (time, height, width, channels).
input_shape = (20, 64, 64, 1)


class Sampling(keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z via the reparameterization trick:
    z = mu + sigma * epsilon, with epsilon ~ N(0, I)."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Draw epsilon with the same dynamic shape as z_mean, so the layer works
        # for both (batch, dim) and time-distributed (batch, time, dim) inputs.
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        # The output has the same shape as z_mean (the first input).
        return input_shape[0]


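# Illustrative shape check for Sampling (the shapes are assumptions matching the
# encoder below, which emits 20 timesteps of 128 latent units; not training code):
#   mu = tf.zeros((2, 20, 128)); logvar = tf.zeros((2, 20, 128))
#   Sampling()([mu, logvar]).shape  # TensorShape([2, 20, 128])

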
class VAE(keras.Model):
    def __init__(self, latent_dim: int = 32, num_embeddings: int = 128, beta: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings  # width of the per-frame latent vectors
        self.beta = beta  # weight on the KL term (beta-VAE style)

        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()

        # Running means of the losses, reported by `fit` each step.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    def get_encoder(self):
        encoder_inputs = keras.Input(shape=input_shape)  # (time, 64, 64, 1)
        # Per-frame convolutions; each stride-2 conv halves the spatial size: 64 -> 32 -> 16.
        x = layers.TimeDistributed(
            layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        )(encoder_inputs)
        x = layers.TimeDistributed(
            layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        x = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)

        # Flatten each 16x16xlatent_dim frame and project to the latent parameters.
        x = layers.TimeDistributed(layers.Flatten())(x)
        mu = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        logvar = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        z = Sampling()([mu, logvar])

        return keras.Model(encoder_inputs, [mu, logvar, z], name="encoder")

    def get_decoder(self):
        latent_inputs = keras.Input(shape=self.encoder.output[2].shape[1:])

        # Mirror the encoder: project, reshape to 16x16 feature maps, then upsample 16 -> 32 -> 64.
        x = layers.TimeDistributed(layers.Dense(16 * 16 * 32, activation="relu"))(latent_inputs)
        x = layers.TimeDistributed(layers.Reshape((16, 16, 32)))(x)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        )(x)
        # Sigmoid keeps pixel predictions in [0, 1], as the binary cross-entropy
        # reconstruction loss in train_step expects.
        decoder_outputs = layers.TimeDistributed(
            layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")
        )(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def train_step(self, data):
        x, y = data  # input sequences and target sequences

        with tf.GradientTape() as tape:
            mu, logvar, z = self.encoder(x)
            reconstruction = self.decoder(z)
            # Sum the per-pixel cross-entropy over time, height, and width,
            # then average over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(y, reconstruction),
                    axis=(1, 2, 3),
                )
            )
            # KL divergence between N(mu, sigma^2) and N(0, I), summed over time
            # and latent dimensions, averaged over the batch.
            kl_loss = -0.5 * (1 + logvar - tf.square(mu) - tf.exp(logvar))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=(1, 2)))
            total_loss = reconstruction_loss + self.beta * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

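    # Optional addition: exposing the trackers via `metrics` lets Keras reset
    # them at the start of each epoch, so logged values are per-epoch means
    # rather than running means over all of training (this mirrors the
    # standard Keras VAE example).
    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]
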
    def call(self, inputs, training=False, mask=None):
        # Forward pass: encode to (mu, logvar, z), then decode z.
        z_mean, z_log_var, z = self.encoder(inputs)
        return self.decoder(z)
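

# Minimal usage sketch (illustrative assumptions: random stand-in data and Adam
# with default settings; swap in real frame sequences scaled to [0, 1]).
if __name__ == "__main__":
    vae = VAE(latent_dim=32, num_embeddings=128, beta=0.5)
    vae.compile(optimizer=keras.optimizers.Adam())

    # Eight random sequences of 20 frames; the model is trained to reconstruct
    # its own input, so inputs double as targets.
    x = tf.random.uniform((8, *input_shape))
    vae.fit(x, x, epochs=1, batch_size=4)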