KublaiKhan1 committed
Commit 59c7686 · verified · 1 Parent(s): 51d8866

Upload folder using huggingface_hub
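
Note: per the commit message above, the folder was uploaded with huggingface_hub. A minimal sketch of such an upload (the repo_id below is a placeholder, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()  # Assumes an authenticated token, e.g. via `huggingface-cli login`.
api.upload_folder(
    folder_path="f32c32_full",
    path_in_repo="f32c32_full",
    repo_id="<user>/<repo>",  # Placeholder.
    repo_type="model",
)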
.gitattributes CHANGED
@@ -124,3 +124,7 @@ f32c4_down/checkpoint/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
  f32c4_down/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
  f32c4_down/checkpointbest.tmp/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
  f32c4_down/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
+ f32c32_full/checkpoint/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
+ f32c32_full/checkpoint.tmp filter=lfs diff=lfs merge=lfs -text
+ f32c32_full/checkpointbest.tmp/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
+ f32c32_full/checkpointbest.tmp.tmp filter=lfs diff=lfs merge=lfs -text
f32c32_full/checkpoint.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:640d4065303f113997eb89cface839503937a7827652069bd0c82c2a101ed6a9
+ size 1676176195
f32c32_full/checkpoint/checkpoint.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:29a85dabf5004a4405402384a3b1a0c0ee81894d022ba01cc726ed8f4254a9e0
+ size 1676176195
f32c32_full/checkpointbest.tmp.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79318fec959fe175e33826cab87069568360411d0b915f84f3fc94d8f4b6db75
+ size 1676176195
f32c32_full/checkpointbest.tmp/checkpointbest.tmp.tmp ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd482a1de53f1ebc027b706f0a1bf3ca1a8ffbaa26f59d138161987dcaa49174
+ size 1676176195
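
Note: the four files above are Git LFS pointers, not the checkpoints themselves. Each pointer records the spec version, the SHA-256 of the payload, and its byte size (1676176195 bytes, roughly 1.56 GiB each). A minimal sketch of parsing one such pointer:

def parse_lfs_pointer(text: str) -> dict:
    # Pointers are three "key value" lines: version, oid, size.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    assert fields["version"] == "https://git-lfs.github.com/spec/v1"
    algo, digest = fields["oid"].split(":", 1)  # e.g. "sha256", "640d40..."
    return {"algo": algo, "digest": digest, "size": int(fields["size"])}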
f32c32_full/train.py ADDED
@@ -0,0 +1,433 @@
+ try:  # For debugging.
+     from localutils.debugger import enable_debug
+     enable_debug()
+ except ImportError:
+     pass
+
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ from absl import app, flags
+ from functools import partial
+ import numpy as np
+ import tqdm
+ import flax
+ import optax
+ import wandb
+ from ml_collections import config_flags
+ import ml_collections
+ import tensorflow_datasets as tfds
+ import tensorflow as tf
+ tf.config.set_visible_devices([], "GPU")
+ tf.config.set_visible_devices([], "TPU")
+ import matplotlib.pyplot as plt
+ from typing import Any
+ import os
+
+ from utils.wandb import setup_wandb, default_wandb_config
+ from utils.train_state import TrainState, target_update
+ from utils.checkpoint import Checkpoint
+ from utils.pretrained_resnet import get_pretrained_embs, get_pretrained_model
+ from utils.fid import get_fid_network, fid_from_stats
+ from models.vqvae import VQVAE
+ from models.discriminator import Discriminator
+
+ FLAGS = flags.FLAGS
+ flags.DEFINE_string('dataset_name', 'imagenet256', 'Dataset name.')
+ flags.DEFINE_string('save_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint", 'Save dir (if not None, save params).')
+ flags.DEFINE_string('load_dir', "/home/lambda/jax-vqvae-vqgan/chkpts/checkpoint.tmp", 'Load dir (if not None, load params from here).')
+ flags.DEFINE_integer('seed', 0, 'Random seed.')
+ flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
+ flags.DEFINE_integer('eval_interval', 1000, 'Eval interval.')
+ flags.DEFINE_integer('save_interval', 1000, 'Save interval.')
+ flags.DEFINE_integer('batch_size', 64, 'Total batch size.')
+ flags.DEFINE_integer('max_steps', int(1_000_000), 'Number of training steps.')
+
+ model_config = ml_collections.ConfigDict({
+     # VQVAE
+     'lr': 0.0001,
+     'beta1': 0.0,  # Originally 0.5.
+     'beta2': 0.99,  # Originally 0.9.
+     'lr_warmup_steps': 2000,
+     'lr_decay_steps': 500_000,  # The reference implementation uses 'lambdalr'.
+     'filters': 128,
+     'num_res_blocks': 2,
+     'channel_multipliers': (1, 1, 2, 2, 4, 4),  # Six stages -> five downsamples (32x).
+     'embedding_dim': 32,  # For FSQ, a good default is 4.
+     'norm_type': 'GN',
+     'weight_decay': 0.05,  # Possibly should be None.
+     'clip_gradient': 1.0,
+     'l2_loss_weight': 1.0,  # The reference implementation actually uses L1.
+     'eps_update_rate': 0.9999,
+     # Quantizer
+     'quantizer_type': 'vq',  # or 'fsq', 'kl'
+     # Quantizer (VQ)
+     'quantizer_loss_ratio': 1,
+     'codebook_size': 1024,
+     'entropy_loss_ratio': 0.1,
+     'entropy_loss_type': 'softmax',
+     'entropy_temperature': 0.01,
+     'commitment_cost': 0.25,
+     # Quantizer (FSQ)
+     'fsq_levels': 5,  # Bins per dimension.
+     # Quantizer (KL)
+     'kl_weight': 0.000001,  # The reference uses 1e-6; 0.001 is a common default.
+     # GAN
+     'g_adversarial_loss_weight': 0.5,
+     'g_grad_penalty_cost': 10,
+     'perceptual_loss_weight': 0.5,
+     'gan_warmup_steps': 25000,
+ })
+
+ wandb_config = default_wandb_config()
+ wandb_config.update({
+     'project': 'vqvae',
+     'name': 'vqvae_{dataset_name}',
+ })
+
+ config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
+ config_flags.DEFINE_config_dict('model', model_config, lock_config=False)
+
+ ##############################################
+ ## Model Definitions.
+ ##############################################
+
+ @jax.vmap
+ def sigmoid_cross_entropy_with_logits(*, labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
+     """Numerically stable sigmoid cross-entropy, from
+     https://github.com/google-research/maskgit/blob/main/maskgit/libml/losses.py
+     """
+     zeros = jnp.zeros_like(logits, dtype=logits.dtype)
+     condition = (logits >= zeros)
+     relu_logits = jnp.where(condition, logits, zeros)
+     neg_abs_logits = jnp.where(condition, -logits, logits)
+     return relu_logits - logits * labels + jnp.log1p(jnp.exp(neg_abs_logits))
+
+ class VQGANModel(flax.struct.PyTreeNode):
+     rng: Any
+     config: dict = flax.struct.field(pytree_node=False)
+     vqvae: TrainState
+     vqvae_eps: TrainState
+     discriminator: TrainState
+
+     # Train G and D.
+     @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
+     def update(self, images, pmap_axis='data'):
+         new_rng, curr_key = jax.random.split(self.rng, 2)
+
+         resnet, resnet_params = get_pretrained_model('resnet50', 'data/resnet_pretrained.npy')
+
+         is_gan_training = 1.0 - (self.vqvae.step < self.config['gan_warmup_steps']).astype(jnp.float32)
+
+         def loss_fn(params_vqvae, params_disc):
+             # Reconstruct image.
+             reconstructed_images, result_dict = self.vqvae(images, params=params_vqvae, rngs={'noise': curr_key})
+             print("Reconstructed images shape", reconstructed_images.shape)
+             print("Input images shape", images.shape)
+             assert reconstructed_images.shape == images.shape
+
+             # GAN loss on VQVAE output.
+             discriminator_fn = lambda x: self.discriminator(x, params=params_disc)
+             real_logit, vjp_fn = jax.vjp(discriminator_fn, images, has_aux=False)
+             gradient = vjp_fn(jnp.ones_like(real_logit))[0]  # Gradient of discriminator output wrt. real images.
+             gradient = gradient.reshape((images.shape[0], -1))
+             gradient = jnp.asarray(gradient, jnp.float32)
+             penalty = jnp.sum(jnp.square(gradient), axis=-1)
+             penalty = jnp.mean(penalty)  # Gradient penalty for training D.
+             fake_logit = discriminator_fn(reconstructed_images)
+             d_loss_real = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(real_logit), logits=real_logit).mean()
+             d_loss_fake = sigmoid_cross_entropy_with_logits(labels=jnp.zeros_like(fake_logit), logits=fake_logit).mean()
+             loss_d = d_loss_real + d_loss_fake + (penalty * self.config['g_grad_penalty_cost'])
+
+             d_loss_for_vae = sigmoid_cross_entropy_with_logits(labels=jnp.ones_like(fake_logit), logits=fake_logit).mean()
+             d_loss_for_vae = d_loss_for_vae * is_gan_training
+
+             real_pools, _ = get_pretrained_embs(resnet_params, resnet, images=images)
+             fake_pools, _ = get_pretrained_embs(resnet_params, resnet, images=reconstructed_images)
+             perceptual_loss = jnp.mean((real_pools - fake_pools)**2)
+
+             l2_loss = jnp.mean((reconstructed_images - images) ** 2)
+             quantizer_loss = result_dict['quantizer_loss'] if 'quantizer_loss' in result_dict else 0.0
+             if self.config['quantizer_type'] == 'kl' or self.config["quantizer_type"] == "kl_two":
+                 quantizer_loss = quantizer_loss * self.config['kl_weight']
+             loss_vae = (l2_loss * FLAGS.model['l2_loss_weight']) \
+                 + (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
+                 + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
+                 + (perceptual_loss * FLAGS.model['perceptual_loss_weight'])
+             codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
+             return (loss_vae, loss_d), {
+                 'loss_vae': loss_vae,
+                 'loss_d': loss_d,
+                 'l2_loss': l2_loss,
+                 'd_loss_for_vae': d_loss_for_vae,
+                 'perceptual_loss': perceptual_loss,
+                 'quantizer_loss': quantizer_loss,
+                 'codebook_usage': codebook_usage,
+             }
+
+         # This is a fancy way to do 'jax.grad' so (loss_vae, params_vqvae) and (loss_d, params_disc) are differentiated.
+         _, grad_fn, info = jax.vjp(loss_fn, self.vqvae.params, self.discriminator.params, has_aux=True)
+         vae_grads, _ = grad_fn((1., 0.))
+         _, d_grads = grad_fn((0., 1.))
+
+         vae_grads = jax.lax.pmean(vae_grads, axis_name=pmap_axis)
+         d_grads = jax.lax.pmean(d_grads, axis_name=pmap_axis)
+         d_grads = jax.tree_map(lambda x: x * is_gan_training, d_grads)
+
+         info = jax.lax.pmean(info, axis_name=pmap_axis)
+         if self.config['quantizer_type'] == 'fsq':
+             info['codebook_usage'] = jnp.sum(info['codebook_usage'] > 0) / info['codebook_usage'].shape[-1]
+
+         updates, new_opt_state = self.vqvae.tx.update(vae_grads, self.vqvae.opt_state, self.vqvae.params)
+         new_params = optax.apply_updates(self.vqvae.params, updates)
+         new_vqvae = self.vqvae.replace(step=self.vqvae.step + 1, params=new_params, opt_state=new_opt_state)
+
+         updates, new_opt_state = self.discriminator.tx.update(d_grads, self.discriminator.opt_state, self.discriminator.params)
+         new_params = optax.apply_updates(self.discriminator.params, updates)
+         new_discriminator = self.discriminator.replace(step=self.discriminator.step + 1, params=new_params, opt_state=new_opt_state)
+
+         info['grad_norm_vae'] = optax.global_norm(vae_grads)
+         info['grad_norm_d'] = optax.global_norm(d_grads)
+         info['update_norm'] = optax.global_norm(updates)
+         info['param_norm'] = optax.global_norm(new_params)
+         info['is_gan_training'] = is_gan_training
+
+         new_vqvae_eps = target_update(new_vqvae, self.vqvae_eps, 1 - self.config['eps_update_rate'])
+
+         new_model = self.replace(rng=new_rng, vqvae=new_vqvae, vqvae_eps=new_vqvae_eps, discriminator=new_discriminator)
+         return new_model, info
+
+     @partial(jax.pmap, axis_name='data', in_axes=(0, 0))
+     def reconstruction(self, images, pmap_axis='data', sampling=False):
+         if not sampling:
+             reconstructed_images, _ = self.vqvae_eps(images)
+         else:
+             new_rng, curr_key = jax.random.split(self.rng, 2)
+             reconstructed_images, _ = self.vqvae(images, rngs={'noise': curr_key})
+
+         reconstructed_images = jnp.clip(reconstructed_images, 0, 1)
+         return reconstructed_images
+
+ ##############################################
+ ## Training Code.
+ ##############################################
+ def main(_):
+     np.random.seed(FLAGS.seed)
+     print("Using devices", jax.local_devices())
+     device_count = len(jax.local_devices())
+     global_device_count = jax.device_count()
+     local_batch_size = FLAGS.batch_size // (global_device_count // device_count)
+     print("Device count", device_count)
+     print("Global device count", global_device_count)
+     print("Global Batch: ", FLAGS.batch_size)
+     print("Node Batch: ", local_batch_size)
+     print("Device Batch:", local_batch_size // device_count)
+
+     # Create wandb logger.
+     if jax.process_index() == 0:
+         setup_wandb(FLAGS.model.to_dict(), **FLAGS.wandb)
+
+     def get_dataset(is_train):
+         if 'imagenet' in FLAGS.dataset_name:
+             def deserialization_fn(data):
+                 image = data['image']
+                 min_side = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
+                 image = tf.image.resize_with_crop_or_pad(image, min_side, min_side)
+                 if 'imagenet256' in FLAGS.dataset_name:
+                     image = tf.image.resize(image, (256, 256))
+                 elif 'imagenet128' in FLAGS.dataset_name:
+                     image = tf.image.resize(image, (128, 128))
+                 else:
+                     raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
+                 if is_train:
+                     image = tf.image.random_flip_left_right(image)
+                 image = tf.cast(image, tf.float32) / 255.0
+                 return image
+
+             split = tfds.split_for_jax_process('train' if is_train else 'validation', drop_remainder=True)
+             print(split)
+             dataset = tfds.load('imagenet2012', split=split, data_dir="/dev/shm")
+             dataset = dataset.map(deserialization_fn, num_parallel_calls=tf.data.AUTOTUNE)
+             dataset = dataset.shuffle(10000, seed=42, reshuffle_each_iteration=True)
+             dataset = dataset.repeat()
+             dataset = dataset.batch(local_batch_size)
+             dataset = dataset.prefetch(tf.data.AUTOTUNE)
+             dataset = tfds.as_numpy(dataset)
+             dataset = iter(dataset)
+             return dataset
+         else:
+             raise ValueError(f"Unknown dataset {FLAGS.dataset_name}")
+
+     dataset = get_dataset(is_train=True)
+     dataset_valid = get_dataset(is_train=False)
+     example_obs = next(dataset)[:1]
+
+     get_fid_activations = get_fid_network()
+     if not os.path.exists('./data/imagenet256_fidstats_openai.npz'):
+         raise ValueError("Please download the FID stats file! See the README.")
+     # truth_fid_stats = np.load('data/imagenet256_fidstats_openai.npz')
+     truth_fid_stats = np.load("./base_stats.npz")
+
+     rng = jax.random.PRNGKey(FLAGS.seed)
+     rng, param_key = jax.random.split(rng)
+     print("Total Memory on device:", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, "GB")
+
+     ###################################
+     # Create the model and put it on devices.
+     ###################################
+     FLAGS.model.image_channels = example_obs.shape[-1]
+     FLAGS.model.image_size = example_obs.shape[1]
+     vqvae_def = VQVAE(FLAGS.model, train=True)
+     vqvae_params = vqvae_def.init({'params': param_key, 'noise': param_key}, example_obs)['params']
+     tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
+     vqvae_ts = TrainState.create(vqvae_def, vqvae_params, tx=tx)
+     vqvae_def_eps = VQVAE(FLAGS.model, train=False)
+     vqvae_eps_ts = TrainState.create(vqvae_def_eps, vqvae_params)
+     print("Total num of VQVAE parameters:", sum(x.size for x in jax.tree_util.tree_leaves(vqvae_params)))
+
+     discriminator_def = Discriminator(FLAGS.model)
+     discriminator_params = discriminator_def.init(param_key, example_obs)['params']
+     tx = optax.adam(learning_rate=FLAGS.model['lr'], b1=FLAGS.model['beta1'], b2=FLAGS.model['beta2'])
+     discriminator_ts = TrainState.create(discriminator_def, discriminator_params, tx=tx)
+     print("Total num of Discriminator parameters:", sum(x.size for x in jax.tree_util.tree_leaves(discriminator_params)))
+
+     model = VQGANModel(rng=rng, vqvae=vqvae_ts, vqvae_eps=vqvae_eps_ts, discriminator=discriminator_ts, config=FLAGS.model)
+
+     if FLAGS.load_dir is not None:
+         try:
+             cp = Checkpoint(FLAGS.load_dir)
+             model = cp.load_model(model)
+             print("Loaded model with step", model.vqvae.step)
+         except Exception:
+             print("Random init")
+     else:
+         print("Random init")
+
+     model = flax.jax_utils.replicate(model, devices=jax.local_devices())
+     jax.debug.visualize_array_sharding(model.vqvae.params['decoder']['Conv_0']['bias'])
+
+     ###################################
+     # Train Loop
+     ###################################
+
+     best_fid = 100000
+
+     for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
+                        smoothing=0.1,
+                        dynamic_ncols=True):
+
+         batch_images = next(dataset)
+         batch_images = batch_images.reshape((len(jax.local_devices()), -1, *batch_images.shape[1:]))  # [devices, batch//devices, ...]
+
+         model, update_info = model.update(batch_images)
+
+         if i % FLAGS.log_interval == 0:
+             update_info = jax.tree_map(lambda x: x.mean(), update_info)
+             train_metrics = {f'training/{k}': v for k, v in update_info.items()}
+             if jax.process_index() == 0:
+                 wandb.log(train_metrics, step=i)
+
+         if i % FLAGS.eval_interval == 0:
+             # Log some reconstructions.
+             reconstructed_images = model.reconstruction(batch_images)  # [devices, batch//devices, 256, 256, 3]
+             valid_images = next(dataset_valid)
+             valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))  # [devices, batch//devices, ...]
+             valid_reconstructed_images = model.reconstruction(valid_images)
+
+             if jax.process_index() == 0:
+                 wandb.log({'batch_image_mean': batch_images.mean()}, step=i)
+                 wandb.log({'reconstructed_images_mean': reconstructed_images.mean()}, step=i)
+                 wandb.log({'batch_image_std': batch_images.std()}, step=i)
+                 wandb.log({'reconstructed_images_std': reconstructed_images.std()}, step=i)
+
+                 # Plot comparisons with matplotlib, each reconstruction side by side.
+                 # Shapes are [devices, batch//devices, 256, 256, 3]; the 2x8 grid was
+                 # sized for 8 devices (TPU v3), so only the first 4 columns are filled here.
+                 fig, axs = plt.subplots(2, 8, figsize=(30, 15))
+                 for j in range(4):
+                     axs[0, j].imshow(batch_images[j, 0], vmin=0, vmax=1)
+                     axs[1, j].imshow(reconstructed_images[j, 0], vmin=0, vmax=1)
+                 wandb.log({'reconstruction': wandb.Image(fig)}, step=i)
+                 plt.close(fig)
+                 fig, axs = plt.subplots(2, 8, figsize=(30, 15))
+                 for j in range(4):
+                     axs[0, j].imshow(valid_images[j, 0], vmin=0, vmax=1)
+                     axs[1, j].imshow(valid_reconstructed_images[j, 0], vmin=0, vmax=1)
+                 wandb.log({'reconstruction_valid': wandb.Image(fig)}, step=i)
+                 plt.close(fig)
+
+             # Validation losses.
+             _, valid_update_info = model.update(valid_images)
+             valid_update_info = jax.tree_map(lambda x: x.mean(), valid_update_info)
+             valid_metrics = {f'validation/{k}': v for k, v in valid_update_info.items()}
+             if jax.process_index() == 0:
+                 wandb.log(valid_metrics, step=i)
+
+             # FID measurement.
+             activations = []
+             activations2 = []
+             for _ in range(780):  # Roughly 50k images at a batch size of 64.
+                 valid_images = next(dataset_valid)
+                 valid_images = valid_images.reshape((len(jax.local_devices()), -1, *valid_images.shape[1:]))
+                 valid_reconstructed_images = model.reconstruction(valid_images)
+
+                 valid_reconstructed_images = jax.image.resize(valid_reconstructed_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
+                                                               method='bilinear', antialias=False)
+                 valid_reconstructed_images = 2 * valid_reconstructed_images - 1
+                 activations += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]
+
+                 # Only needed when saving reference stats:
+                 # valid_reconstructed_images = jax.image.resize(valid_images, (valid_images.shape[0], valid_images.shape[1], 299, 299, 3),
+                 #                                               method='bilinear', antialias=False)
+                 # valid_reconstructed_images = 2 * valid_reconstructed_images - 1
+                 # activations2 += [np.array(get_fid_activations(valid_reconstructed_images))[..., 0, 0, :]]
+
+             # TODO: use all_gather to get activations from all devices.
+             activations = np.concatenate(activations, axis=0)
+             activations = activations.reshape((-1, activations.shape[-1]))
+
+             # activations2 = np.concatenate(activations2, axis=0)
+             # activations2 = activations2.reshape((-1, activations2.shape[-1]))
+
+             print("doing this much FID", activations.shape)  # e.g. (8192, 2048).
+             mu1 = np.mean(activations, axis=0)
+             sigma1 = np.cov(activations, rowvar=False)
+             fid = fid_from_stats(mu1, sigma1, truth_fid_stats['mu'], truth_fid_stats['sigma'])
+
+             # To save reference stats instead, compute mu2/sigma2 from activations2,
+             # save them, and exit:
+             # mu2 = np.mean(activations2, axis=0)
+             # sigma2 = np.cov(activations2, rowvar=False)
+             # np.savez("base.npz", mu=mu2, sigma=sigma2)
+             # exit()
+             # Used when loading the saved base stats:
+             # fid = fid_from_stats(mu1, sigma1, mu2, sigma2)
+
+             if jax.process_index() == 0:
+                 wandb.log({'validation/fid': fid}, step=i)
+                 print("validation FID at step", i, fid)
+                 # If this FID beats the previous best, checkpoint it.
+                 if fid < best_fid:
+                     model_single = flax.jax_utils.unreplicate(model)
+                     cp = Checkpoint(FLAGS.save_dir + "best.tmp")
+                     cp.set_model(model_single)
+                     cp.save()
+                     best_fid = fid
+
+         if (i % FLAGS.save_interval == 0) and (FLAGS.save_dir is not None):
+             if jax.process_index() == 0:
+                 model_single = flax.jax_utils.unreplicate(model)
+                 cp = Checkpoint(FLAGS.save_dir)
+                 cp.set_model(model_single)
+                 cp.save()
+
+ if __name__ == '__main__':
+     app.run(main)
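
Note on the gradient computation in update() above: a single jax.vjp over a function returning (loss_vae, loss_d) yields both d(loss_vae)/d(params_vqvae) and d(loss_d)/d(params_disc) by seeding the cotangents with (1., 0.) and (0., 1.). A standalone sketch of the same trick (toy inputs, not the repo's code):

import jax
import jax.numpy as jnp

def two_losses(a, b):
    loss_a = jnp.sum(a ** 2) + jnp.sum(a * b)         # Depends on both inputs.
    loss_b = jnp.sum((b - 1.0) ** 2) + jnp.sum(a * b)
    return loss_a, loss_b

a, b = jnp.ones(3), jnp.zeros(3)
_, grad_fn = jax.vjp(two_losses, a, b)
grads_a, _ = grad_fn((1.0, 0.0))  # d(loss_a)/da; d(loss_a)/db is discarded.
_, grads_b = grad_fn((0.0, 1.0))  # d(loss_b)/db; d(loss_b)/da is discarded.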
f32c32_full/vqvae.py ADDED
@@ -0,0 +1,453 @@
+ from typing import Any
+ import flax.linen as nn
+ import jax.numpy as jnp
+ import functools
+ import ml_collections
+ import jax
+ from einops import rearrange
+
+ ###########################
+ ### Helper Modules
+ ### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
+ ###########################
+
+ def get_norm_layer(norm_type):
+     """Normalization layer."""
+     if norm_type == 'BN':
+         raise NotImplementedError
+     elif norm_type == 'LN':
+         norm_fn = functools.partial(nn.LayerNorm)
+     elif norm_type == 'GN':
+         norm_fn = functools.partial(nn.GroupNorm)
+     else:
+         raise NotImplementedError
+     return norm_fn
+
+ def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
+     pool_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add,
+                                      (1,) + window_shape + (1,),
+                                      (1,) + strides + (1,), padding)
+     pool_denom = jax.lax.reduce_window(
+         jnp.ones_like(x), 0.0, jax.lax.add, (1,) + window_shape + (1,),
+         (1,) + strides + (1,), padding)
+     return pool_sum / pool_denom
+
+ def upsample(x, factor=2):
+     n, h, w, c = x.shape
+     x = jax.image.resize(x, (n, h * factor, w * factor, c), method='nearest')
+     return x
+
+ def dsample(x):
+     return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding='same')
+
+ def pixel_unshuffle(x, factor):
+     # Space-to-depth: fold each (factor x factor) spatial block into channels.
+     return rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)', b1=factor, b2=factor)
+
+ def pixel_shuffle(x, factor):
+     # Depth-to-space: the inverse of pixel_unshuffle.
+     return rearrange(x, '... h w (c b1 b2) -> ... (h b1) (w b2) c', b1=factor, b2=factor)
+
+ def squared_euclidean_distance(a: jnp.ndarray,
+                                b: jnp.ndarray,
+                                b2: jnp.ndarray = None) -> jnp.ndarray:
+     """Computes the pairwise squared Euclidean distance.
+
+     Args:
+       a: float32: (n, d): An array of points.
+       b: float32: (m, d): An array of points.
+       b2: float32: (d, m): b square transpose.
+
+     Returns:
+       d: float32: (n, m): Where d[i, j] is the squared Euclidean distance between
+       a[i] and b[j].
+     """
+     if b2 is None:
+         b2 = jnp.sum(b.T**2, axis=0, keepdims=True)
+     a2 = jnp.sum(a**2, axis=1, keepdims=True)
+     ab = jnp.matmul(a, b.T)
+     d = a2 - 2 * ab + b2
+     return d
+
+ def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
+     """Calculates the entropy loss. Affinity is the similarity/distance matrix."""
+     flat_affinity = affinity.reshape(-1, affinity.shape[-1])
+     flat_affinity /= temperature
+     probs = jax.nn.softmax(flat_affinity, axis=-1)
+     log_probs = jax.nn.log_softmax(flat_affinity + 1e-5, axis=-1)
+     if loss_type == "softmax":
+         target_probs = probs
+     elif loss_type == "argmax":
+         codes = jnp.argmax(flat_affinity, axis=-1)
+         onehots = jax.nn.one_hot(
+             codes, flat_affinity.shape[-1], dtype=flat_affinity.dtype)
+         onehots = probs - jax.lax.stop_gradient(probs - onehots)
+         target_probs = onehots
+     else:
+         raise ValueError("Entropy loss {} not supported".format(loss_type))
+     avg_probs = jnp.mean(target_probs, axis=0)
+     avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
+     sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
+     loss = sample_entropy - avg_entropy
+     return loss
+
+ def sg(x):
+     return jax.lax.stop_gradient(x)
+
+ ###########################
+ ### Modules
+ ###########################
+
+ class ResBlock(nn.Module):
+     """Basic Residual Block."""
+     filters: int
+     norm_fn: Any
+     activation_fn: Any
+
+     @nn.compact
+     def __call__(self, x):
+         input_dim = x.shape[-1]
+         residual = x
+         x = self.norm_fn()(x)
+         x = self.activation_fn(x)
+         x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
+         x = self.norm_fn()(x)
+         x = self.activation_fn(x)
+         x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
+
+         if input_dim != self.filters:  # If input and output dims don't match, use a 1x1-projected skip.
+             residual = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(x)
+         return x + residual
+
+ class Encoder(nn.Module):
+     """From [H,W,D] image to [H',W',D'] embedding, using Conv layers."""
+     config: ml_collections.ConfigDict
+
+     def setup(self):
+         self.filters = self.config.filters  # Base filter count.
+         self.num_res_blocks = self.config.num_res_blocks
+         self.channel_multipliers = self.config.channel_multipliers
+         self.embedding_dim = self.config.embedding_dim
+         self.norm_type = self.config.norm_type
+         self.activation_fn = nn.swish
+
+     def pixels(self, x):
+         # Space-to-depth, then average each group of 4 channels: a parameter-free
+         # 2x downsample that preserves the channel count (DC-AE style).
+         x = pixel_unshuffle(x, 2)
+         B, H, W, C = x.shape
+         x = jnp.reshape(x, (B, H, W, C // 4, 4))
+         x = jnp.mean(x, axis=-1)
+         return x
+
+     @nn.compact
+     def __call__(self, x):
+         print("Initializing encoder.")
+         norm_fn = get_norm_layer(norm_type=self.norm_type)
+         block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
+         print("Incoming encoder shape", x.shape)
+         x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
+         print('Encoder layer', x.shape)
+         num_blocks = len(self.channel_multipliers)
+
+         # The way SD works: 2x resnet (shape unchanged), then downsample, repeated
+         # 3 times for an 8x downsample; then an extra resnet block, and only THEN
+         # the projection from 512 channels down to 8 / 4.
+         # The DC-AE architecture is roughly: 4x resnet, downsample, and then an
+         # EfficientViT downsampling stage.
+         for i in range(num_blocks):
+             filters = self.filters * self.channel_multipliers[i]
+             for _ in range(self.num_res_blocks):
+                 x = ResBlock(filters, **block_args)(x)
+
+             if i < num_blocks - 1:  # Downsample after every block except the last.
+                 print("doing downsample")
+                 # DC-AE style: add a channel-averaged space-to-depth shortcut to the
+                 # pooled downsample.
+                 if self.channel_multipliers[i] != -1:
+                     print("pre pixels", x.shape)
+                     pixel_x = self.pixels(x)
+                     print("pixel_x", pixel_x.shape)
+                     x = dsample(x) + pixel_x
+                     print("post", x.shape)
+                 else:
+                     x = dsample(x)
+                     print("other post", x.shape)
+
+             print('Encoder layer', x.shape)
+
+         # After downsampling is done, the mid resnet blocks.
+         for _ in range(self.num_res_blocks):
+             x = ResBlock(filters, **block_args)(x)
+         print('Encoder layer final', x.shape)
+
+         x = norm_fn()(x)
+         x = self.activation_fn(x)
+         last_dim = self.embedding_dim * 2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
+         x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
+         print("Final embeddings are size", x.shape)
+         return x
+
+ class Decoder(nn.Module):
+     """From [H',W',D'] embedding to [H,W,D] image, using Conv layers."""
+
+     config: ml_collections.ConfigDict
+
+     def setup(self):
+         self.filters = self.config.filters
+         self.num_res_blocks = self.config.num_res_blocks
+         self.channel_multipliers = self.config.channel_multipliers
+         self.norm_type = self.config.norm_type
+         self.image_channels = self.config.image_channels
+         self.activation_fn = nn.swish
+
+     def pixels(self, x):
+         # Duplicate each channel 4 times, then depth-to-space: the parameter-free
+         # inverse of the encoder's channel-averaging downsample.
+         print("pixels shape", x.shape)
+         x = jnp.repeat(x, 4, axis=-1)
+         x = pixel_shuffle(x, 2)
+         print("done duplicating", x.shape)
+         return x
+
+     @nn.compact
+     def __call__(self, x):
+         norm_fn = get_norm_layer(norm_type=self.norm_type)
+         block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
+         num_blocks = len(self.channel_multipliers)
+         filters = self.filters * self.channel_multipliers[-1]
+         print("Decoder incoming shape", x.shape)
+
+         # Nothing extra needed here; this conv brings the width back up to 512.
+         x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
+         print("Decoder input", x.shape)
+
+         # The mid block.
+         for _ in range(self.num_res_blocks):
+             x = ResBlock(filters, **block_args)(x)
+         print('Mid Block Decoder layer', x.shape)
+
+         # The first two sets of blocks are just 3 resnets with no channel changes;
+         # we are already at 4x = 512.
+         for i in reversed(range(num_blocks)):
+             filters = self.filters * self.channel_multipliers[i]
+             for _ in range(self.num_res_blocks + 1):
+                 x = ResBlock(filters, **block_args)(x)
+             if i > 0:
+                 # Channel-duplicating pixel shuffle every time we upsample spatially.
+                 pixel = self.pixels(x)
+                 print("pre up", x.shape)
+                 x = upsample(x, 2)
+                 print("post up", x.shape)
+                 x = x + pixel
+                 x = nn.Conv(filters, kernel_size=(3, 3))(x)
+             print('Decoder layer', x.shape)
+         x = norm_fn()(x)
+         x = self.activation_fn(x)
+         x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
+         return x
+
+ class VectorQuantizer(nn.Module):
+     """Basic vector quantizer."""
+     config: ml_collections.ConfigDict
+     train: bool
+
+     @nn.compact
+     def __call__(self, x):
+         codebook_size = self.config.codebook_size
+         emb_dim = x.shape[-1]
+         codebook = self.param(
+             "codebook",
+             jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
+             (codebook_size, emb_dim))
+         codebook = jnp.asarray(codebook)  # (codebook_size, emb_dim)
+         distances = jnp.reshape(
+             squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
+             x.shape[:-1] + (codebook_size,))  # [x, codebook_size] similarity matrix.
+         encoding_indices = jnp.argmin(distances, axis=-1)
+         encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
+         quantized = self.quantize(encoding_onehot)
+         result_dict = dict()
+         if self.train:
+             e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
+             q_latent_loss = jnp.mean((quantized - sg(x))**2)
+             entropy_loss = 0.0
+             if self.config.entropy_loss_ratio != 0:
+                 entropy_loss = entropy_loss_fn(
+                     -distances,
+                     loss_type=self.config.entropy_loss_type,
+                     temperature=self.config.entropy_temperature
+                 ) * self.config.entropy_loss_ratio
+             e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
+             q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
+             entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
+             loss = e_latent_loss + q_latent_loss + entropy_loss
+             result_dict = dict(
+                 quantizer_loss=loss,
+                 e_latent_loss=e_latent_loss,
+                 q_latent_loss=q_latent_loss,
+                 entropy_loss=entropy_loss)
+             quantized = x + jax.lax.stop_gradient(quantized - x)
+
+         result_dict.update({
+             "z_ids": encoding_indices,
+         })
+         return quantized, result_dict
+
+     def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
+         codebook = jnp.asarray(self.variables["params"]["codebook"])
+         return jnp.dot(encoding_onehot, codebook)
+
+     def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
+         codebook = self.variables["params"]["codebook"]
+         return jnp.take(codebook, ids, axis=0)
+
+ class KLQuantizer(nn.Module):
+     config: ml_collections.ConfigDict
+     train: bool
+
+     @nn.compact
+     def __call__(self, x):
+         emb_dim = x.shape[-1] // 2  # Use half as means, half as logvars.
+         means = x[..., :emb_dim]
+         logvars = x[..., emb_dim:]
+         if not self.train:
+             result_dict = dict()
+             return means, result_dict
+         else:
+             noise = jax.random.normal(self.make_rng("noise"), means.shape)
+             stds = jnp.exp(0.5 * logvars)
+             z = means + stds * noise
+             kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
+             result_dict = dict(quantizer_loss=kl_loss)
+             return z, result_dict
+
+ class AEQuantizer(nn.Module):  # Plain autoencoder: no quantization, no loss.
+     config: ml_collections.ConfigDict
+     train: bool
+
+     @nn.compact
+     def __call__(self, x):
+         result_dict = dict()
+         return x, result_dict
+
+ class KLQuantizerTwo(nn.Module):
+     config: ml_collections.ConfigDict
+     train: bool
+
+     @nn.compact
+     def __call__(self, x):
+         # Unlike KLQuantizer, the channels are not split into means/logvars.
+         # We actually want to do mean and std on the batch axis? We start as
+         # (b, hw, 8) and go to (b, hw, 4), with mean and std over those.
+         if not self.train:
+             result_dict = dict()
+             return x, result_dict
+         else:
+             # The previous run took the mean over axis 0.
+             means = jnp.mean(x, axis=[1, 2, 3])
+             stds = jnp.std(x, axis=[1, 2, 3])
+             noise = jax.random.normal(self.make_rng("noise"), means.shape)
+             logvars = .5 * jnp.log(stds)
+             z = means + stds * noise
+             # We just don't need to return z for this; we return x instead.
+             # This is the denoising version.
+             kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
+             result_dict = dict(quantizer_loss=kl_loss)
+             return x, result_dict
+
+ class FSQuantizer(nn.Module):
+     config: ml_collections.ConfigDict
+     train: bool
+
+     @nn.compact
+     def __call__(self, x):
+         assert self.config['fsq_levels'] % 2 == 1, "FSQ levels must be odd."
+         z = jnp.tanh(x)  # [-1, 1]
+         z = z * (self.config['fsq_levels'] - 1) / 2  # [-(levels-1)/2, (levels-1)/2]
+         zhat = jnp.round(z)  # e.g. [-2, -1, 0, 1, 2]
+         quantized = z + jax.lax.stop_gradient(zhat - z)
+         quantized = quantized / (self.config['fsq_levels'] // 2)  # [-1, 1], but quantized.
+         result_dict = dict()
+
+         # Diagnostics for codebook usage.
+         zhat_scaled = zhat + self.config['fsq_levels'] // 2
+         basis = jnp.concatenate((jnp.array([1]), jnp.cumprod(jnp.array([self.config['fsq_levels']] * (x.shape[-1] - 1))))).astype(jnp.uint32)
+         idx = (zhat_scaled * basis).sum(axis=-1).astype(jnp.uint32)
+         idx_flat = idx.reshape(-1)
+         usage = jnp.bincount(idx_flat, length=self.config['fsq_levels']**x.shape[-1])
+
+         result_dict.update({
+             "z_ids": zhat,
+             'usage': usage
+         })
+         return quantized, result_dict
+
+ class VQVAE(nn.Module):
+     """VQVAE model."""
+     config: ml_collections.ConfigDict
+     train: bool
+
+     def setup(self):
+         """VQVAE setup."""
+         if self.config['quantizer_type'] == 'vq':
+             self.quantizer = VectorQuantizer(config=self.config, train=self.train)
+         elif self.config['quantizer_type'] == 'kl':
+             self.quantizer = KLQuantizer(config=self.config, train=self.train)
+         elif self.config['quantizer_type'] == 'fsq':
+             self.quantizer = FSQuantizer(config=self.config, train=self.train)
+         elif self.config['quantizer_type'] == 'ae':
+             self.quantizer = AEQuantizer(config=self.config, train=self.train)
+         elif self.config["quantizer_type"] == "kl_two":
+             self.quantizer = KLQuantizerTwo(config=self.config, train=self.train)
+         self.encoder = Encoder(config=self.config)
+         self.decoder = Decoder(config=self.config)
+
+     def encode(self, image):
+         encoded_feature = self.encoder(image)
+         quantized, result_dict = self.quantizer(encoded_feature)
+         print("After quant", quantized.shape)
+         return quantized, result_dict
+
+     def decode(self, z_vectors):
+         print("z_vectors shape", z_vectors.shape)
+         reconstructed = self.decoder(z_vectors)
+         return reconstructed
+
+     def decode_from_indices(self, z_ids):
+         z_vectors = self.quantizer.decode_ids(z_ids)
+         reconstructed_image = self.decode(z_vectors)
+         return reconstructed_image
+
+     def encode_to_indices(self, image):
+         encoded_feature = self.encoder(image)
+         _, result_dict = self.quantizer(encoded_feature)
+         ids = result_dict["z_ids"]
+         return ids
+
+     def __call__(self, image):
+         quantized, result_dict = self.encode(image)
+         outputs = self.decoder(quantized)
+         return outputs, result_dict
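
Note on FSQuantizer above: tanh squashes activations into [-1, 1], values are scaled so that rounding yields fsq_levels integer bins per dimension, the round is made differentiable with a straight-through estimator, and the result is rescaled back into [-1, 1]. A standalone sketch of that rounding (toy input, not the repo's code):

import jax
import jax.numpy as jnp

def fsq_round(x, levels=5):  # levels must be odd, e.g. 5 -> bins {-2, -1, 0, 1, 2}.
    z = jnp.tanh(x) * (levels - 1) / 2
    zhat = z + jax.lax.stop_gradient(jnp.round(z) - z)  # Identity gradient through round.
    return zhat / (levels // 2)

print(fsq_round(jnp.linspace(-3.0, 3.0, 7)))  # Only `levels` distinct output values.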