import tensorflow as tf
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
model = tf.keras.models.load_model('VQ-VAE-Model')
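
# The VectorQuantizer layer below mirrors the custom layer from the Keras
# VQ-VAE example (https://keras.io/examples/generative/vq_vae/); it is
# re-declared here so the trained codebook can be queried directly when
# visualizing the discrete latent representation.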
class VectorQuantizer(tf.keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # This parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Calculate vector quantization loss and add that to the layer. You can learn more
        # about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
        # the original paper to get a handle on the formulation of the loss function.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate the squared L2 distance between the inputs and the codes,
        # expanded as ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b so only a
        # single matmul against the codebook is needed.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices
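
# Rebuild the quantizer used at training time (64 codes, each 16-dimensional)
# and load its trained codebook. `embeddings.npy` is assumed to hold the
# codebook weights with shape (embedding_dim, num_embeddings) = (16, 64),
# and the encoder is assumed to be the second layer of the saved model.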
vq_object = VectorQuantizer(64, 16)
embs = np.load('embeddings.npy')
vq_object.embeddings = embs
encoder = model.layers[1]
# Load and preprocess the MNIST test set.
_, (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = np.expand_dims(x_test, -1)
x_test_scaled = (x_test / 255.0) - 0.5
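
# Inputs are scaled to [-0.5, 0.5] here (matching training), so 0.5 is added
# back before displaying images in the plotting helpers below.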
def make_subplot_reconstruction(original, reconstructed):
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].imshow(reconstructed[row_idx].squeeze() + 0.5)
        axs[row_idx, 1].axis('off')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Reconstruction")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig
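
# The encoder is assumed to map each 28x28 MNIST digit to a 7x7 grid of
# codebook indices (e.g. via two stride-2 convolutions), hence the hard-coded
# range(7) loops below.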
def make_subplot_latent(original, reconstructed):
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].matshow(original[row_idx].squeeze())
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].matshow(reconstructed[row_idx].squeeze())
        axs[row_idx, 1].axis('off')
        # Overlay each codebook index on its cell (text takes x=column, y=row).
        for i in range(7):
            for j in range(7):
                c = reconstructed[row_idx][i, j]
                axs[row_idx, 1].text(j, i, str(c), va='center', ha='center')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig
def plot_sample(mode):
    sample = np.random.choice(x_test.shape[0], 3)
    test_images = x_test_scaled[sample]

    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)

    encoded_out = encoder.predict(test_images)
    encoded = encoded_out.reshape(-1, encoded_out.shape[-1])
    quant = vq_object.get_code_indices(encoded)
    quant = quant.numpy().reshape(encoded_out.shape[:-1])
    return make_subplot_latent(test_images, quant)
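
# Gradio UI: a radio button selects between the full reconstruction and the
# discrete (codebook-index) latent view; `plot_sample` receives the selected
# mode and returns a matplotlib figure rendered by gr.Plot.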
demo = gr.Blocks()

with demo:
    gr.Markdown("# Vector-Quantized Variational Autoencoders (VQ-VAE)")
    gr.Markdown("""This space demonstrates the use of VQ-VAEs. Like traditional VAEs, VQ-VAEs try to learn a useful latent representation.
However, a VQ-VAE's latent space is **discrete** rather than continuous. Below, we can view how well this model compresses and reconstructs MNIST digits and, more importantly, see the
discretized latent representation. These discrete representations can then be paired with a network like PixelCNN to generate novel images.
VQ-VAEs are one of the tools used by DALL-E and are among the few models that perform on par with VAEs while using a discrete latent space.
For more information, check out this [paper](https://arxiv.org/abs/1711.00937) and
[example](https://keras.io/examples/generative/vq_vae/).<br>
Full credits for this example go to [Sayak Paul](https://twitter.com/RisingSayak).<br>
The model card can be found [here](https://huggingface.co/brendenc/VQ-VAE).<br>
Demo by [Brenden Connors](https://www.linkedin.com/in/brenden-connors-6a0512195)""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                radio = gr.Radio(choices=['Reconstruction', 'Discrete Latent Representation'])
            with gr.Row():
                button = gr.Button('Run')
        with gr.Column():
            out = gr.Plot()

    button.click(plot_sample, radio, out)

demo.launch()