import tensorflow as tf
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

# Load the trained VQ-VAE model.
model = tf.keras.models.load_model('VQ-VAE-Model')


class VectorQuantizer(tf.keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta  # Best kept in [0.25, 2] as per the paper.

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Record the input shape, then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Calculate the vector quantization loss and add it to the layer. See
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/ for
        # adding losses to layers, and the original paper for the loss formulation.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Compute squared L2 distances between the inputs and the codebook entries.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


# Rebuild the quantizer and load the trained codebook embeddings.
vq_object = VectorQuantizer(64, 16)
embs = np.load('embeddings.npy')
vq_object.embeddings = embs

# Encoder sub-model of the VQ-VAE.
encoder = model.layers[1]

# Load and preprocess the MNIST test set.
_, (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = np.expand_dims(x_test, -1)
x_test_scaled = (x_test / 255.0) - 0.5


def make_subplot_reconstruction(original, reconstructed):
    """Plot three original digits next to their reconstructions."""
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].imshow(reconstructed[row_idx].squeeze() + 0.5)
        axs[row_idx, 1].axis('off')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Reconstruction")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig


def make_subplot_latent(original, reconstructed):
    """Plot three original digits next to their 7x7 grids of discrete code indices."""
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].matshow(original[row_idx].squeeze())
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].matshow(reconstructed[row_idx].squeeze())
        axs[row_idx, 1].axis('off')
        for i in range(7):
            for j in range(7):
                c = reconstructed[row_idx][i, j]
                # text() takes (x, y) = (column, row), so pass (j, i) to label cell [i, j].
                axs[row_idx, 1].text(j, i, str(c), va='center', ha='center')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig


def plot_sample(mode):
    # Pick three random test digits.
    sample = np.random.choice(x_test.shape[0], 3)
    test_images = x_test_scaled[sample]

    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)

    # Otherwise, show the discrete latent representation: encode, then map each
    # latent vector to the index of its nearest codebook entry.
    encoded_out = encoder.predict(test_images)
    encoded = encoded_out.reshape(-1, encoded_out.shape[-1])
    quant = vq_object.get_code_indices(encoded)
    quant = quant.numpy().reshape(encoded_out.shape[:-1])
    return make_subplot_latent(test_images, quant)


demo = gr.Blocks()

with demo:
    gr.Markdown("# Vector-Quantized Variational Autoencoders (VQ-VAE)")
    gr.Markdown("""This space demonstrates the use of VQ-VAEs. Similar to traditional VAEs, VQ-VAEs learn a useful latent representation. However, a VQ-VAE's latent space is **discrete** rather than continuous. Below, we can view how well this model compresses and reconstructs MNIST digits, but more importantly, we can see its discretized latent representation. These discrete representations can then be paired with a network like PixelCNN to generate novel images. VQ-VAEs are one of the tools used by DALL-E and are among the few models with a discrete latent space that perform on par with VAEs. For more information, check out this [paper](https://arxiv.org/abs/1711.00937) and [example](https://keras.io/examples/generative/vq_vae/).
Full credits for this example go to [Sayak Paul](https://twitter.com/RisingSayak).
The model card can be found [here](https://huggingface.co/brendenc/VQ-VAE).
Demo by [Brenden Connors](https://www.linkedin.com/in/brenden-connors-6a0512195)""")

    # Layout: mode selection and run button on the left, plot output on the right.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                radio = gr.Radio(choices=['Reconstruction', 'Discrete Latent Representation'])
            with gr.Row():
                button = gr.Button('Run')
        with gr.Column():
            out = gr.Plot()

    button.click(plot_sample, radio, out)

demo.launch()
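
# --------------------------------------------------------------------------------
# Hedged sketch (not executed by the app): the blurb above notes that the discrete
# code grid is enough to reconstruct an image. Assuming this model follows the layer
# ordering of the Keras VQ-VAE example (input, encoder, quantizer, decoder) and that
# `code_grid` is a hypothetical (7, 7) array of code indices (e.g. `quant[0]` inside
# `plot_sample`), the indices could be decoded back to an image roughly like this:
#
#   decoder = model.layers[3]                             # assumed decoder sub-model
#   codebook = tf.transpose(vq_object.embeddings)         # (num_embeddings, embedding_dim)
#   latents = tf.gather(codebook, code_grid.reshape(-1))  # look up codebook vectors
#   latents = tf.reshape(latents, (1, 7, 7, 16))          # restore the latent grid
#   image = decoder.predict(latents)                      # reconstruction from codes alone
# --------------------------------------------------------------------------------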