brendenc committed on
Commit
bb2bcd2
1 Parent(s): 4406647

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
# NOTE: the `tf.saved_model.load('VQ-VAE')` call that used to stand here was
# removed. Its result was bound to `model` and then overwritten by
# `tf.keras.models.load_model('VQ-VAE', ...)` further down before anything
# read it, so it only loaded the same artifact a second time at startup.
7
+
8
class VectorQuantizer(tf.keras.layers.Layer):
    """Keras layer that snaps each input vector to its nearest codebook entry.

    The codebook is stored as a ``(embedding_dim, num_embeddings)`` variable,
    i.e. one code per column. ``call`` adds the commitment and codebook losses
    to the layer and returns the quantized tensor using a straight-through
    estimator.

    Args:
        num_embeddings: Number of codes (columns) in the codebook.
        embedding_dim: Length of each code; the inputs are flattened so that
            their last dimension has this size.
        beta: Weight of the commitment loss term.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # As per the paper, this parameter is best kept between [0.25, 2].
        self.beta = beta

        # Codebook that the inputs get quantized against, uniformly initialised.
        codebook_shape = (self.embedding_dim, self.num_embeddings)
        initializer = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=initializer(shape=codebook_shape, dtype="float32"),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        """Quantize ``x`` and register the vector-quantization loss.

        Returns a tensor with the same shape as ``x`` whose forward value is
        the quantized tensor, while gradients flow straight through to ``x``.
        """
        # Remember the incoming shape, then collapse everything except the
        # trailing `embedding_dim` axis.
        original_shape = tf.shape(x)
        flat_x = tf.reshape(x, [-1, self.embedding_dim])

        # Look up the nearest code for every vector and rebuild the tensor
        # from the selected codebook columns.
        code_ids = self.get_code_indices(flat_x)
        one_hot_codes = tf.one_hot(code_ids, self.num_embeddings)
        quantized = tf.matmul(one_hot_codes, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, original_shape)

        # Vector-quantization loss: the commitment term (scaled by beta) only
        # sends gradients to x, the codebook term only sends gradients to the
        # codes. See the original VQ-VAE paper for the formulation, and
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/
        # for how layers register losses.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator.
        return x + tf.stop_gradient(quantized - x)

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of ``flattened_inputs``, the index of the closest code.

        Closeness is the squared Euclidean distance, expanded as
        ``|a|^2 + |e|^2 - 2 a.e`` so that it can be computed with one matmul.
        """
        dot_products = tf.matmul(flattened_inputs, self.embeddings)
        input_sq_norms = tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
        code_sq_norms = tf.reduce_sum(self.embeddings ** 2, axis=0)
        distances = input_sq_norms + code_sq_norms - 2 * dot_products

        # The column with the smallest distance is the selected code.
        return tf.argmin(distances, axis=1)
65
+
66
# Stand-alone quantizer (64 codes of length 16) whose codebook is replaced by
# the array stored in `embeddings.npy`; it is used for code-index lookups
# outside of the loaded model.
vq_object = VectorQuantizer(64, 16)
embs = np.load('embeddings.npy')
vq_object.embeddings = embs

# Load the saved model; `encoder` is taken to be its second layer
# (NOTE(review): confirm against the saved architecture).
model = tf.keras.models.load_model(
    'VQ-VAE',
    custom_objects={'vector_quantizer': vq_object},
)
encoder = model.layers[1]

# Data load and preprocess: keep only the MNIST test images, append a trailing
# channel axis and rescale the pixel values from [0, 255] to [-0.5, 0.5].
_, (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = x_test[..., np.newaxis]
x_test_scaled = x_test / 255.0 - 0.5
76
+
77
def make_subplot_reconstruction(original, reconstructed, n_rows=3):
    """Plot input images side by side with their reconstructions.

    Args:
        original: Indexable collection of images; each item is squeezed and
            shifted by +0.5 before being shown.
        reconstructed: Indexable collection of reconstructed images, treated
            the same way as ``original``.
        n_rows: Number of image pairs to draw (one pair per row). Defaults to
            3, the number the original implementation always drew.

    Returns:
        The matplotlib figure holding an ``n_rows`` x 2 grid of axes.
        NOTE: the figure stays registered with pyplot; a caller that is done
        with it may want to ``plt.close(fig)``.
    """
    # Size the figure up front (10 x 10.5 inches for three rows) so that
    # tight_layout() below is computed for the final size, not the default one.
    # squeeze=False keeps `axs` two-dimensional even when n_rows == 1.
    fig, axs = plt.subplots(
        n_rows, 2, figsize=(10, 3.5 * n_rows), squeeze=False
    )
    for row_idx in range(n_rows):
        original_ax, reconstruction_ax = axs[row_idx]
        # The +0.5 shift mirrors the original code; presumably it undoes a
        # -0.5 centring applied by the caller (confirm against the caller).
        original_ax.imshow(original[row_idx].squeeze() + 0.5)
        original_ax.axis('off')
        reconstruction_ax.imshow(reconstructed[row_idx].squeeze() + 0.5)
        reconstruction_ax.axis('off')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Reconstruction")
    fig.tight_layout()
    return fig
90
+
91
def make_subplot_latent(original, reconstructed, n_rows=3):
    """Plot input images next to their discrete latent code maps.

    Args:
        original: Indexable collection of images; each item is squeezed and
            drawn with ``matshow``.
        reconstructed: Indexable collection of 2-D arrays of code indices, one
            per image. Each cell is coloured by ``matshow`` and labelled with
            its index. (The original code assumed a fixed 7x7 grid; the grid
            size is now read from the array.)
        n_rows: Number of image/code-map pairs to draw (one pair per row).
            Defaults to 3, the number the original implementation always drew.

    Returns:
        The matplotlib figure holding an ``n_rows`` x 2 grid of axes.
        NOTE: the figure stays registered with pyplot; a caller that is done
        with it may want to ``plt.close(fig)``.
    """
    # Size the figure up front (10 x 10.5 inches for three rows) so that
    # tight_layout() below is computed for the final size, not the default one.
    # squeeze=False keeps `axs` two-dimensional even when n_rows == 1.
    fig, axs = plt.subplots(
        n_rows, 2, figsize=(10, 3.5 * n_rows), squeeze=False
    )
    for row_idx in range(n_rows):
        image_ax, code_ax = axs[row_idx]
        image_ax.matshow(original[row_idx].squeeze())
        image_ax.axis('off')

        codes = reconstructed[row_idx].squeeze()
        code_ax.matshow(codes)
        code_ax.axis('off')

        # matshow draws element [i, j] at x=j (column), y=i (row), so the
        # label for codes[i, j] has to be placed at (j, i). Placing it at
        # (i, j) would print the transposed map on top of the colours.
        n_code_rows, n_code_cols = codes.shape
        for i in range(n_code_rows):
            for j in range(n_code_cols):
                code_ax.text(j, i, str(codes[i, j]), va='center', ha='center')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    fig.tight_layout()
    return fig
109
+
110
def plot_sample(mode):
    """Draw three random test images and plot them in the requested view.

    Args:
        mode: ``'Reconstruction'`` plots the images next to the model's
            reconstructions. Any other value (including ``None``) plots them
            next to their discrete latent code maps.

    Returns:
        The matplotlib figure produced by the matching plotting helper.
    """
    # Sample without replacement so the same image cannot appear twice in
    # one figure.
    sample = np.random.choice(x_test.shape[0], 3, replace=False)
    test_images = x_test_scaled[sample]

    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)

    # Latent view: encode, flatten to one row per latent vector, look up the
    # nearest code for every row, then restore the encoder's leading
    # dimensions so there is one code map per image.
    encoded_out = encoder.predict(test_images)
    flat_latents = encoded_out.reshape(-1, encoded_out.shape[-1])
    code_ids = vq_object.get_code_indices(flat_latents)
    code_maps = code_ids.numpy().reshape(encoded_out.shape[:-1])

    return make_subplot_latent(test_images, code_maps)
122
+
123
import gradio as gr

# UI wiring: a radio button selects the view, and the figure returned by
# `plot_sample` is rendered in a plot component.
radio = gr.Radio(choices=['Reconstruction', 'Latent Representation'])
out = gr.Plot()

gr.Interface(fn=plot_sample, inputs=radio, outputs=out).launch()