import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Each sample is a sequence of 20 frames of 64x64 grayscale images.
input_shape = (20, 64, 64, 1)

class Sampling(keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z via the reparameterization trick."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Draw epsilon with the same (dynamic) shape as z_mean, so the layer
        # also works when the batch size is unknown at graph-build time.
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def compute_output_shape(self, input_shape):
        # z has the same shape as z_mean (the first input).
        return input_shape[0]
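
# A quick shape sketch for Sampling (illustrative values, assumes eager mode):
#   mu = tf.zeros((2, 20, 128)); logvar = tf.zeros((2, 20, 128))
#   z = Sampling()([mu, logvar])  # -> shape (2, 20, 128)
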
class VAE(keras.Model):
    def __init__(self, latent_dim: int = 32, num_embeddings: int = 128,
                 beta: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        self.beta = beta  # weight on the KL term (beta-VAE style)
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
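
    @property
    def metrics(self):
        # Standard Keras pattern for models with a custom train_step: exposing
        # the trackers here lets model.fit() reset them at each epoch start.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]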

    def get_encoder(self):
        encoder_inputs = keras.Input(shape=input_shape)
        # Per-frame convolutional downsampling: 64x64 -> 32x32 -> 16x16.
        x = layers.TimeDistributed(
            layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        )(encoder_inputs)
        x = layers.TimeDistributed(
            layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        x = layers.TimeDistributed(layers.Conv2D(self.latent_dim, 1, padding="same"))(x)
        x = layers.TimeDistributed(layers.Flatten())(x)
        # Per-timestep mean and log-variance of the approximate posterior.
        mu = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        logvar = layers.TimeDistributed(layers.Dense(self.num_embeddings))(x)
        z = Sampling()([mu, logvar])
        return keras.Model(encoder_inputs, [mu, logvar, z], name="encoder")

    def get_decoder(self):
        latent_inputs = keras.Input(shape=self.encoder.output[2].shape[1:])
        x = layers.TimeDistributed(
            layers.Dense(16 * 16 * 32, activation="relu")
        )(latent_inputs)
        x = layers.TimeDistributed(layers.Reshape((16, 16, 32)))(x)
        # Per-frame upsampling back to full resolution: 16x16 -> 32x32 -> 64x64.
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        )(x)
        # Sigmoid keeps outputs in [0, 1], matching the binary cross-entropy
        # reconstruction loss used in train_step.
        decoder_outputs = layers.TimeDistributed(
            layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")
        )(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            mu, logvar, z = self.encoder(x)
            reconstruction = self.decoder(z)
            # Per-pixel BCE summed over time, height, and width, then averaged
            # over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(y, reconstruction),
                    axis=(1, 2, 3),
                )
            )
            # KL divergence between N(mu, exp(logvar)) and N(0, I), summed
            # over time steps and latent dimensions.
            kl_loss = -0.5 * (1 + logvar - tf.square(mu) - tf.exp(logvar))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=(1, 2)))
            total_loss = reconstruction_loss + self.beta * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, inputs, training=False, mask=None):
        z_mean, z_log_var, z = self.encoder(inputs)
        return self.decoder(z)
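

# Minimal smoke test: a sketch that assumes random data in [0, 1]; real
# training would use actual frame sequences shaped like `input_shape`.
if __name__ == "__main__":
    vae = VAE(latent_dim=32, num_embeddings=128, beta=0.5)
    vae.compile(optimizer=keras.optimizers.Adam())
    # 8 random "videos" of 20 frames, 64x64 pixels, 1 channel.
    x = np.random.rand(8, *input_shape).astype("float32")
    # train_step unpacks (x, y); for an autoencoder the target is the input.
    vae.fit(x, x, epochs=1, batch_size=4)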