# VQ-VAE / app.py
# Author: brendenc
# Last commit: "Update app.py"
import tensorflow as tf
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
# Directory holding the trained VQ-VAE.
# NOTE(review): presumably a Keras SavedModel directory shipped with the app — confirm.
MODEL_DIR = 'VQ-VAE-Model'
model = tf.keras.models.load_model(MODEL_DIR)
class VectorQuantizer(tf.keras.layers.Layer):
    """Vector-quantization bottleneck for a VQ-VAE.

    Each ``embedding_dim``-sized vector along the input's last axis is replaced
    by its nearest codebook vector. The layer registers the codebook and
    commitment losses via ``add_loss`` and passes gradients straight through
    the (non-differentiable) quantization step.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Length of each codebook vector; must equal the size of
            the input's last axis.
        beta: Weight of the commitment loss.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        # Initialise the Keras base layer first: it must be set up before
        # attributes (notably the tracked tf.Variable) are assigned, and this
        # is where **kwargs such as `name` are actually consumed.
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # This parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize. The codebook is
        # stored column-wise with shape (embedding_dim, num_embeddings).
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
        )

    def call(self, x):
        """Quantize ``x`` and register the vector-quantization loss.

        Args:
            x: Tensor whose last axis has size ``embedding_dim``.

        Returns:
            Tensor with the same shape as ``x``. The forward value is the
            nearest codebook vector; the gradient w.r.t. ``x`` is the identity
            (straight-through estimator).
        """
        # Remember the input shape, then flatten everything except the
        # embedding axis so that each row is one vector to quantize.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization: nearest code index per row -> one-hot -> codebook lookup.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Vector-quantization loss (see the VQ-VAE paper). The commitment term
        # (codes held fixed) pulls the inputs toward their codes; the codebook
        # term (inputs held fixed) pulls the codes toward the inputs. For how
        # add_loss works see
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator: forward pass yields `quantized`, backward
        # pass copies the gradient onto `x` unchanged.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        """Return the index of the nearest codebook vector for every row.

        Args:
            flattened_inputs: 2-D tensor/array of shape (N, embedding_dim).

        Returns:
            int64 tensor of shape (N,) with values in [0, num_embeddings).
        """
        # Squared Euclidean distance ||x||^2 + ||e||^2 - 2 x.e between every
        # row and every codebook column -> (N, num_embeddings) matrix.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )
        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices
# Stand-alone quantizer used only to turn encoder outputs into code indices.
# The constructor sizes only shape the random initial codebook; that codebook
# is replaced right away by the one saved from training.
vq_object = VectorQuantizer(num_embeddings=64, embedding_dim=16)
embs = np.load('embeddings.npy')
# NOTE(review): this swaps the tf.Variable for a plain numpy array. Assumes
# embeddings.npy holds the trained codebook as (embedding_dim, num_embeddings)
# in float32 — confirm against the export script.
vq_object.embeddings = embs
# NOTE(review): assumes layer index 1 of the loaded model is the encoder —
# confirm with model.summary().
encoder = model.layers[1]
# Data load and preprocess: the demo only needs the MNIST test images.
(_, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
# Append a channel axis: (N, H, W) -> (N, H, W, 1).
x_test = x_test[..., np.newaxis]
# Map pixel values 0..255 onto [-0.5, 0.5]; the plotting helpers add 0.5 back.
x_test_scaled = x_test / 255.0 - 0.5
def make_subplot_reconstruction(original, reconstructed):
    """Plot each original image beside its reconstruction, one pair per row.

    Args:
        original: Batch of images shifted to be centred on 0 (as produced for
            ``x_test_scaled``); one image is drawn per row.
        reconstructed: Model outputs for ``original`` with the same scaling;
            must have at least ``len(original)`` entries.

    Returns:
        A matplotlib Figure with ``len(original)`` rows and two columns.
    """
    n_rows = len(original)
    # squeeze=False keeps `axs` 2-D even for a single row, so the [row, col]
    # indexing below works for any batch size (previously fixed at 3).
    fig, axs = plt.subplots(n_rows, 2, squeeze=False)
    for row_idx in range(n_rows):
        # Add 0.5 back to undo the -0.5 shift applied during preprocessing.
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)
        axs[row_idx, 1].imshow(reconstructed[row_idx].squeeze() + 0.5)
    fig.set_size_inches(10, 10.5)
    return fig
def make_subplot_latent(original, reconstructed):
    """Plot each original image beside its grid of discrete code indices.

    Args:
        original: Batch of images shifted to be centred on 0 (as produced for
            ``x_test_scaled``); one image is drawn per row.
        reconstructed: Per-image 2-D integer grids of codebook indices (the
            quantizer output reshaped to the encoder's spatial layout).

    Returns:
        A matplotlib Figure with ``len(original)`` rows and two columns.
    """
    n_rows = len(original)
    # squeeze=False keeps `axs` 2-D even for a single row.
    fig, axs = plt.subplots(n_rows, 2, squeeze=False)
    for row_idx in range(n_rows):
        # Left column: the input digit (undo the -0.5 preprocessing shift).
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)

        # Right column: colour each latent cell by its code index. This also
        # sets the axes limits that the text labels below are placed within.
        codes = np.asarray(reconstructed[row_idx])
        axs[row_idx, 1].matshow(codes)
        n_y, n_x = codes.shape  # grid size taken from the data, not hard-coded
        for i in range(n_y):
            for j in range(n_x):
                # matshow draws element [i, j] at x=j, y=i; label it at the
                # same spot, otherwise the numbers appear transposed.
                axs[row_idx, 1].text(
                    j, i, str(codes[i, j]), va='center', ha='center'
                )
    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    fig.set_size_inches(10, 10.5)
    return fig
def plot_sample(mode):
    """Pick three random MNIST test digits and visualise them.

    Args:
        mode: Value of the UI radio button. 'Reconstruction' plots the digits
            beside the model's reconstructions; any other value (including
            None when nothing is selected) plots their discrete code indices.

    Returns:
        A matplotlib Figure for the ``gr.Plot`` output.
    """
    # Sample without replacement so the three rows show three distinct digits.
    sample = np.random.choice(x_test.shape[0], 3, replace=False)
    test_images = x_test_scaled[sample]

    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)

    # Encode, flatten all but the embedding axis, find the nearest code for
    # every vector, then restore the encoder's spatial layout (drop last axis).
    encoded_out = encoder.predict(test_images)
    encoded = encoded_out.reshape(-1, encoded_out.shape[-1])
    quant = vq_object.get_code_indices(encoded)
    quant = quant.numpy().reshape(encoded_out.shape[:-1])
    return make_subplot_latent(test_images, quant)
# Gradio UI: a mode selector and a Run button on the left, the plot on the right.
demo = gr.Blocks()
with demo:
    gr.Markdown("# Vector-Quantized Variational Autoencoders (VQ-VAE)")
    # Continuation lines of this Markdown string stay at column 0 on purpose:
    # leading indentation would render them as a Markdown code block.
    gr.Markdown("""This space demonstrates the use of VQ-VAEs. Similar to traditional VAEs, VQ-VAEs try to create a useful latent representation.
However, a VQ-VAE's latent space is **discrete** rather than continuous. Below, we can view how well this model compresses and reconstructs MNIST digits, but, more importantly, we can see a
discretized latent representation. These discrete representations can then be paired with a network like PixelCNN to generate novel images.
VQ-VAEs are one of the tools used by DALL-E and are among the few models that perform on par with VAEs while using a discrete latent space.
For more information, check out this [paper](https://arxiv.org/abs/1711.00937).<br>
Full credits for this example go to [Sayak Paul](https://twitter.com/RisingSayak).<br>
Model card can be found [here](https://huggingface.co/brendenc/VQ-VAE).<br>
Demo by [Brenden Connors](https://www.linkedin.com/in/brenden-connors-6a0512195)""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Default to 'Reconstruction' so Run never sends None, which
                # plot_sample would otherwise treat as the latent view.
                radio = gr.Radio(
                    choices=['Reconstruction', 'Discrete Latent Representation'],
                    value='Reconstruction',
                )
            with gr.Row():
                button = gr.Button('Run')
        with gr.Column():
            out = gr.Plot()
    button.click(plot_sample, radio, out)