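"""Gradio demo for a VQ-VAE: reconstruct MNIST test digits and visualize
their discrete latent codes from the trained codebook."""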
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

class VectorQuantizer(tf.keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # Commitment loss weight; best kept in [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )
    def call(self, x):
        # Record the input shape, then flatten the inputs
        # while keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization: replace each vector with its nearest codebook entry.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Calculate the vector quantization loss and add it to the layer.
        # You can learn more about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/.
        # Check the original paper for the loss formulation: the commitment
        # loss pulls the encoder output toward its code, the codebook loss
        # pulls the code toward the encoder output, and `stop_gradient`
        # controls which side receives gradients.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator: the forward pass uses the quantized
        # values, while gradients flow back to `x` as if quantization were
        # the identity.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized
    def get_code_indices(self, flattened_inputs):
        # Compute squared L2 distances between the inputs and the codes,
        # using the expansion ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices of the minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices
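
# A minimal, self-contained check of the layer above, kept in a function so it
# does not run on import. The shapes are illustrative assumptions chosen to
# match the 64-code, 16-dimensional codebook instantiated below; nothing here
# is read from the saved model.
def _smoke_test_vector_quantizer():
    vq = VectorQuantizer(num_embeddings=64, embedding_dim=16)
    x = tf.random.normal((1, 7, 7, 16))
    out = vq(x)
    assert out.shape == x.shape  # quantization preserves the feature-map shape
    assert len(vq.losses) == 1   # one combined commitment + codebook loss term

    # The straight-through estimator should pass gradients through to `x`.
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = vq(x)
    assert tape.gradient(y, x) is not None

    # Cross-check the expanded-distance trick in `get_code_indices` against
    # brute-force pairwise distances computed via broadcasting.
    flat = tf.reshape(x, [-1, 16])
    fast = vq.get_code_indices(flat)
    diffs = flat[:, None, :] - tf.transpose(vq.embeddings)[None, :, :]
    slow = tf.argmin(tf.reduce_sum(diffs ** 2, axis=-1), axis=1)
    tf.debugging.assert_equal(fast, slow)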

# Rebuild the quantizer with the trained codebook (64 codes of dimension 16),
# overriding the randomly initialized embeddings with the saved weights.
vq_object = VectorQuantizer(64, 16)
embs = np.load('embeddings.npy')
vq_object.embeddings = embs
model = tf.keras.models.load_model(
    'VQ-VAE', custom_objects={'vector_quantizer': vq_object}
)
encoder = model.layers[1]  # the encoder sub-model of the VQ-VAE
# Load and preprocess the MNIST test split.
_, (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = np.expand_dims(x_test, -1)
x_test_scaled = (x_test / 255.0) - 0.5  # scale pixels to [-0.5, 0.5]
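
# Illustrative sanity check: the scaled pixels should lie in [-0.5, 0.5], so
# the plotting helpers below can shift them back by +0.5 for display.
assert x_test_scaled.min() >= -0.5 and x_test_scaled.max() <= 0.5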

def make_subplot_reconstruction(original, reconstructed):
    # Plot three originals next to their reconstructions, shifting pixels
    # back from [-0.5, 0.5] to [0, 1] for display.
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].imshow(reconstructed[row_idx].squeeze() + 0.5)
        axs[row_idx, 1].axis('off')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Reconstruction")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def make_subplot_latent(original, reconstructed):
    # Plot three originals next to their 7x7 grids of discrete code indices,
    # overlaying each index value on its cell.
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].matshow(original[row_idx].squeeze())
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].matshow(reconstructed[row_idx].squeeze())
        axs[row_idx, 1].axis('off')
        for i in range(7):
            for j in range(7):
                c = reconstructed[row_idx][i, j]
                # text() takes (x, y) = (column, row), so the value at array
                # position (i, j) must be drawn at (j, i).
                axs[row_idx, 1].text(j, i, str(c), va='center', ha='center')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def plot_sample(mode):
    # Draw three random test digits and render either their reconstructions
    # or their discrete latent codes, depending on the selected mode.
    sample = np.random.choice(x_test.shape[0], 3)
    test_images = x_test_scaled[sample]
    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)

    # Otherwise, map each encoded image to its nearest codebook indices.
    encoded_out = encoder.predict(test_images)
    encoded = encoded_out.reshape(-1, encoded_out.shape[-1])
    quant = vq_object.get_code_indices(encoded)
    quant = quant.numpy().reshape(encoded_out.shape[:-1])
    return make_subplot_latent(test_images, quant)
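
# Direct usage outside Gradio (illustrative; 'latents.png' is a hypothetical
# output path): render one random batch and save the figure to disk.
#
#     plot_sample('Latent Representation').savefig('latents.png')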

# Build the Gradio UI: a mode selector feeding the plotting function. A default
# value is set so submitting without a selection still yields a valid mode.
radio = gr.Radio(choices=['Reconstruction', 'Latent Representation'],
                 value='Reconstruction', label='Mode')
out = gr.Plot()
gr.Interface(plot_sample, radio, out).launch()