import tensorflow as tf
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np


class VectorQuantizer(tf.keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta  # This parameter is best kept in [0.25, 2], as per the paper.

        # Initialize the codebook embeddings that the inputs will be quantized to.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Record the input shape, then flatten the inputs while keeping
        # `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Calculate vector quantization loss and add that to the layer. You can learn more
        # about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
        # the original paper to get a handle on the formulation of the loss function.
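        # Concretely, the two terms added below are:
        #     beta * ||sg[z_q] - z_e||^2   (commitment loss, trains the encoder)
        #   +        ||z_q - sg[z_e]||^2   (codebook loss, trains the embeddings)
        # where sg[.] denotes stop_gradient, z_e = x, and z_q = quantized.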
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator.
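        # The residual is wrapped in stop_gradient, so the forward pass returns
        # `quantized` while the backward pass treats this op as the identity,
        # copying the decoder's gradients straight through to the encoder.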
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Compute squared L2 distances between the inputs and the codebook
        # entries via the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices
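
# A minimal sanity-check sketch (never called by the app): it assumes the same
# 64-code, 16-dimensional configuration used below and verifies that the
# quantizer preserves shape and emits one code index per spatial position.
def _sanity_check_quantizer():
    vq = VectorQuantizer(num_embeddings=64, embedding_dim=16)
    dummy = tf.random.normal((2, 7, 7, 16))  # batch of two 7x7 feature maps
    quantized = vq(dummy)
    assert quantized.shape == dummy.shape  # quantization preserves shape
    indices = vq.get_code_indices(tf.reshape(dummy, [-1, 16]))
    assert indices.shape == (2 * 7 * 7,)  # one codebook index per position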

# Load the saved model here, after the custom layer is defined, so it can be
# registered via `custom_objects` and the load is robust across save formats.
model = tf.keras.models.load_model(
    'VQ-VAE-Model', custom_objects={'VectorQuantizer': VectorQuantizer}
)

# Rebuild a stand-alone quantizer (64 codes, each 16-dimensional) and swap in
# the trained codebook; TF ops accept the NumPy array directly.
vq_object = VectorQuantizer(num_embeddings=64, embedding_dim=16)
embs = np.load('embeddings.npy')
vq_object.embeddings = embs
encoder = model.layers[1]  # the encoder sub-model of the saved VQ-VAE

# Load and preprocess the MNIST test split.
_, (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = np.expand_dims(x_test, -1)        # add a channel axis: (N, 28, 28, 1)
x_test_scaled = (x_test / 255.0) - 0.5     # scale pixels to [-0.5, 0.5]
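# Note: scaling to [-0.5, 0.5] is assumed to match the preprocessing used when
# the model was trained; the plots below shift images back by +0.5 for display.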

def make_subplot_reconstruction(original, reconstructed):
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].imshow(reconstructed[row_idx].squeeze() + 0.5)
        axs[row_idx, 1].axis('off')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Reconstruction")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def make_subplot_latent(original, reconstructed):
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].matshow(original[row_idx].squeeze())
        axs[row_idx, 0].axis('off')

        axs[row_idx, 1].matshow(reconstructed[row_idx].squeeze())
        axs[row_idx, 1].axis('off')
        # Overlay each code index on its cell. matshow draws element [i, j] at
        # (x=j, y=i), so the text coordinates are (j, i).
        for i in range(7):
            for j in range(7):
                c = reconstructed[row_idx][i, j]
                axs[row_idx, 1].text(j, i, str(c), va='center', ha='center')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def plot_sample(mode):
    # Pick three random test digits.
    sample = np.random.choice(x_test.shape[0], 3)
    test_images = x_test_scaled[sample]
    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)

    # Otherwise show the discrete latents: encode, flatten the spatial grid,
    # look up the nearest codebook index per position, and restore the grid.
    encoded_out = encoder.predict(test_images)
    encoded = encoded_out.reshape(-1, encoded_out.shape[-1])
    quant = vq_object.get_code_indices(encoded)
    quant = quant.numpy().reshape(encoded_out.shape[:-1])

    return make_subplot_latent(test_images, quant)
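
# With this encoder the latent grid is 7x7, so `quant` above has shape
# (3, 7, 7), holding integer code indices in [0, 64).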

with gr.Blocks() as demo:
    gr.Markdown("# Vector-Quantized Variational Autoencoders (VQ-VAE)")
    gr.Markdown("""This space is to demonstrate the use of VQ-VAEs. Similar to tradiitonal VAEs, VQ-VAEs try to create a useful latent representation.
     However, VQ-VAEs latent space is **discrete** rather than continuous. Below, we can view how well this model compresses and reconstructs MNIST digits, but more importantly, we can see a
     discretized latent representation. These discrete representations can then be paired with a network like PixelCNN to generate novel images.
     
     VQ-VAEs are one of the tools used by DALL-E and are some of the only models that perform on par with VAEs but with a discrete latent space.
     For more information check out this [paper](https://arxiv.org/abs/1711.00937) and
     [example](https://keras.io/examples/generative/vq_vae/).<br>
     Full Credits for this example go to [Sayak Paul](https://twitter.com/RisingSayak).<br>
     Model card can be found [here](https://huggingface.co/brendenc/VQ-VAE).<br>
     Demo by [Brenden Connors](https://www.linkedin.com/in/brenden-connors-6a0512195)""")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                radio = gr.Radio(
                    choices=['Reconstruction', 'Discrete Latent Representation'],
                    value='Reconstruction',  # default so "Run" works before a selection is made
                )
            with gr.Row():
                button = gr.Button('Run')
        with gr.Column():
            out = gr.Plot()

    button.click(plot_sample, inputs=radio, outputs=out)

demo.launch()
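
# When running locally, `demo.launch(share=True)` would additionally create a
# temporary public link; the plain launch above is sufficient for a hosted Space.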