import tensorflow as tf
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np


class VectorQuantizer(tf.keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta  # Best kept in [0.25, 2], per the VQ-VAE paper.

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Record the input shape, then flatten the inputs
        # while keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Compute the vector quantization loss and add it to the layer. See
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/
        # for adding losses to layers, and the original VQ-VAE paper for the
        # formulation of the loss.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Compute squared L2 distances between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices
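
# A minimal, illustrative sanity check (not part of the app flow): the
# quantizer should preserve the input shape while snapping every
# `embedding_dim`-d vector to its nearest codebook entry. The random
# tensor below is purely hypothetical.
_demo_vq = VectorQuantizer(num_embeddings=64, embedding_dim=16)
_demo_in = tf.random.normal((1, 7, 7, 16))
assert _demo_vq(_demo_in).shape == _demo_in.shape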

vq_object = VectorQuantizer(num_embeddings=64, embedding_dim=16)
embs = np.load('embeddings.npy')
vq_object.embeddings = embs  # Swap in the trained codebook weights for inference.
model = tf.keras.models.load_model(
    'VQ-VAE', custom_objects={'vector_quantizer': vq_object}
)
encoder = model.layers[1]  # The encoder sub-model of the VQ-VAE.

# Load and preprocess the MNIST test split.
_, (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = np.expand_dims(x_test, -1)
x_test_scaled = (x_test / 255.0) - 0.5
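
# Quick illustrative check: after scaling, pixel values lie in [-0.5, 0.5],
# which is why the plotting helpers below add 0.5 before calling imshow.
assert x_test_scaled.min() >= -0.5 and x_test_scaled.max() <= 0.5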

def make_subplot_reconstruction(original, reconstructed):
    # Plot three originals next to their reconstructions, shifted back to [0, 1].
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].imshow(reconstructed[row_idx].squeeze() + 0.5)
        axs[row_idx, 1].axis('off')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Reconstruction")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def make_subplot_latent(original, reconstructed):
    # Plot three originals next to their 7x7 discrete code-index maps,
    # overlaying each code index on its cell.
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].matshow(original[row_idx].squeeze())
        axs[row_idx, 0].axis('off')

        axs[row_idx, 1].matshow(reconstructed[row_idx].squeeze())
        axs[row_idx, 1].axis('off')
        for i in range(7):
            for j in range(7):
                c = reconstructed[row_idx][i, j]
                # text() takes (x, y), so column j comes first to keep the
                # overlaid index aligned with its cell in matshow.
                axs[row_idx, 1].text(j, i, str(c), va='center', ha='center')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def plot_sample(mode):
    # Draw three random test digits and show either their reconstructions
    # or their discrete latent codes, depending on the selected mode.
    sample = np.random.choice(x_test.shape[0], 3)
    test_images = x_test_scaled[sample]
    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)
    encoded_out = encoder.predict(test_images)
    encoded = encoded_out.reshape(-1, encoded_out.shape[-1])
    quant = vq_object.get_code_indices(encoded)
    quant = quant.numpy().reshape(encoded_out.shape[:-1])

    return make_subplot_latent(test_images, quant)
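
# Hypothetical usage outside Gradio (e.g. for local debugging); the file
# name is illustrative:
# fig = plot_sample('Latent Representation')
# fig.savefig('latent_demo.png')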

radio = gr.Radio(choices=['Reconstruction', 'Latent Representation'])
out = gr.Plot()

gr.Interface(plot_sample, radio, out).launch()