import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

input_shape = (20, 64, 64, 1)  # (frames, height, width, channels) of each input clip


class VectorQuantizer(layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # This parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Record the input shape, then flatten the inputs
        # keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Calculate the vector quantization loss and add it to the layer. You can
        # learn more about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/.
        # Check the original paper for the loss formulation: the commitment loss
        # is beta * ||sg[z_q] - z_e||^2 and the codebook loss is
        # ||z_q - sg[z_e]||^2, where sg denotes stop_gradient.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator: gradients flow from `quantized`
        # straight back to the encoder output `x`.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate squared L2 distances between the inputs and the codes,
        # expanded as ||x||^2 + ||e||^2 - 2 <x, e>.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices
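
# A minimal sanity-check sketch (an illustration, not part of the original
# upload): run the quantizer on random features to confirm that the output
# keeps the input shape and that the combined commitment/codebook loss is
# registered via add_loss(). The feature shape below is an assumed example.
def demo_vector_quantizer():
    vq = VectorQuantizer(num_embeddings=128, embedding_dim=32)
    features = tf.random.normal((2, 16, 16, 32))  # (batch, height, width, embedding_dim)
    quantized = vq(features)
    print(quantized.shape)  # (2, 16, 16, 32): shape is preserved
    print(len(vq.losses))   # 1: the combined commitment + codebook loss
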
class VQVAE(keras.Model):
    def __init__(self, train_variance: float, latent_dim: int = 32, num_embeddings: int = 128, **kwargs):
        super().__init__(**kwargs)
        # Variance of the training data, used to scale the reconstruction loss.
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        self.vqvae = self.get_vqvae()
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    @property
    def metrics(self):
        # Expose the trackers so Keras resets them at the start of each epoch.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def get_encoder(self):
        # Apply the same 2D convolutions to every frame of the clip via
        # TimeDistributed, downsampling 64x64 -> 16x16.
        encoder_inputs = keras.Input(shape=input_shape)
        x = layers.TimeDistributed(
            layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        )(encoder_inputs)
        x = layers.TimeDistributed(
            layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        )(x)
        encoder_outputs = layers.TimeDistributed(
            layers.Conv2D(self.latent_dim, 1, padding="same")
        )(x)
        return keras.Model(encoder_inputs, encoder_outputs, name="encoder")

    def get_decoder(self):
        # Mirror the encoder: the latent input shape is taken from a throwaway
        # encoder instance, and transposed convolutions upsample back to 64x64.
        latent_inputs = keras.Input(shape=self.get_encoder().output.shape[1:])
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        )(latent_inputs)
        x = layers.TimeDistributed(
            layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        )(x)
        decoder_outputs = layers.TimeDistributed(
            layers.Conv2DTranspose(1, 3, padding="same")
        )(x)
        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def get_vqvae(self):
        # Chain encoder -> vector quantizer -> decoder into a single model.
        self.vq_layer = VectorQuantizer(
            self.num_embeddings, self.latent_dim, name="vector_quantizer"
        )
        self.encoder = self.get_encoder()
        self.decoder = self.get_decoder()
        inputs = keras.Input(shape=input_shape)
        encoder_outputs = self.encoder(inputs)
        quantized_latents = self.vq_layer(encoder_outputs)
        reconstructions = self.decoder(quantized_latents)
        return keras.Model(inputs, reconstructions, name="vq_vae")

    def train_step(self, data):
        # `data` yields (inputs, targets); for plain reconstruction both are
        # the same clip.
        x, y = data
        with tf.GradientTape() as tape:
            # Outputs from the VQ-VAE.
            reconstructions = self.vqvae(x)

            # Calculate the losses. The VQ losses were added by the
            # VectorQuantizer layer via add_loss().
            reconstruction_loss = (
                tf.reduce_mean((y - reconstructions) ** 2) / self.train_variance
            )
            total_loss = reconstruction_loss + sum(self.vqvae.losses)

        # Backpropagation.
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }

    def call(self, inputs, training=False, mask=None):
        return self.vqvae(inputs)
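
# A minimal end-to-end sketch (an assumption for illustration, not part of the
# original upload): build the model, compile it with Adam, and fit it on
# random clips shaped like `input_shape`. In real use, `train_variance` would
# be the variance of the actual training data, and the dataset would yield
# (input_clip, target_clip) pairs.
if __name__ == "__main__":
    dummy_clips = np.random.random((4,) + input_shape).astype("float32")
    vqvae = VQVAE(train_variance=float(np.var(dummy_clips)))
    vqvae.compile(optimizer=keras.optimizers.Adam())
    vqvae.fit(dummy_clips, dummy_clips, epochs=1, batch_size=2)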