"""VQ-VAE autoencoder with ViT backbone."""
|
|
|
|
import functools
|
|
from typing import Mapping, Optional, Sequence, Union
|
|
|
|
from big_vision import utils
|
|
from big_vision.models import common
|
|
from big_vision.models import vit
|
|
|
|
import einops
|
|
import flax.linen as nn
|
|
import flax.training.checkpoints
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
|
|
|
|
partial = functools.partial
|
|
|
|
|
|
|
|
|
|
|
|
PERTURB = 0.001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
|
|
@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0))
|
|
def quantize(x, e):
|
|
dist = jnp.sum(x * x)[None] - 2 * x.dot(e.T) + jnp.sum(e * e, axis=1)
|
|
idx = jnp.argmin(dist)
|
|
x_q = jax.lax.stop_gradient(e[idx] - x) + x
|
|
return x_q, idx
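# Example shapes: for tokens of shape (batch, code_len, dim) and a codebook of
# shape (dict_size, dim), `quantize` returns quantized tokens of shape
# (batch, code_len, dim) and integer codes of shape (batch, code_len).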


def split_the_most_frequent_embedding(state):
  """Splits most frequent embedding into two and eliminates least frequent.

  Args:
    state: a dict that contains the current jax rng, embeddings and their
      counts.

  Returns:
    New dict with the updated jax rng, embeddings and counts.
  """
  rng, e, c = state["rng"], state["dictionary"], state["counts"]
  rng, rng_local = jax.random.split(rng)

  i_max = jnp.argmax(c)
  i_min = jnp.argmin(c)

  # Overwrite the least used codeword with a slightly perturbed copy of the
  # most used one, then split the accumulated sum and count between the two.
  e = e.at[i_min].set(
      e[i_max] * jax.random.uniform(rng_local, (e.shape[1],), jnp.float32,
                                    1.0-PERTURB, 1.0+PERTURB))

  c = c.at[i_min].set(c[i_max] / 2.0)
  c = c.at[i_max].set(c[i_max] / 2.0)

  e = e.at[i_min].set(e[i_min] / 2.0)
  e = e.at[i_max].set(e[i_max] / 2.0)

  return {"rng": rng, "dictionary": e, "counts": c}


class Model(nn.Module):
  """ViT-based VQ-VAE model."""

  inputs: Mapping[str, Sequence[int]]
  outputs: Mapping[str, Sequence[int]]
  input_size: Sequence[int] = (256, 256)
  patch_size: Sequence[int] = (8, 8)
  code_len: int = 256
  width: int = 768
  enc_depth: int = 6
  dec_depth: int = 6
  mlp_dim: Optional[int] = None
  num_heads: int = 12
  posemb: str = "learn"
  rep_size: Union[int, bool] = False
  dropout: float = 0.0
  reinit: Optional[Sequence[str]] = None
  head_zeroinit: bool = True
  dict_size: int = 512  # Number of codewords in the dictionary.
  codeword_dim: Optional[int] = None
  dict_momentum: float = 0.995  # EMA momentum for the dictionary updates.
  quantize: bool = True
  # Axis name used to aggregate codebook statistics across devices; set to
  # None when running without pmap, e.g. in tests.
  statistics_axis_name: str = "batch"
  # Discounted-count threshold below which a codeword is considered unused
  # and gets resampled via `split_the_most_frequent_embedding`.
  min_count: float = 0.1
  with_encoder_ctx: bool = False
  with_decoder_ctx: bool = False
  code_dropout: str = "none"
  bottleneck_resize: bool = False
  zero_decoder_seq: bool = False

  def setup(self):
    self.grid_size = np.array(self.input_size) // np.array(self.patch_size)

    self.embeddings = {
        k: nn.DenseGeneral(features=(self.width,), axis=range(-len(shape), 0),
                           name=f"embedding_{k}")
        for k, shape in self.inputs.items()
    }

    kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
    self.heads = {
        k: nn.DenseGeneral(features=shape, name=f"head_{k}", **kw)
        for k, shape in self.outputs.items()
    }

    if self.with_encoder_ctx:
      self.stem_conv_ctx_enc = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_enc_embedding")

    if self.with_decoder_ctx:
      self.stem_conv_ctx_dec = nn.Conv(
          self.width, self.patch_size, strides=self.patch_size,
          padding="VALID", name="ctx_dec_embedding")

    self.pos_embedding_encoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_encoder")
    self.encoder = vit.Encoder(
        depth=self.enc_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="encoder")

    if not self.bottleneck_resize:
      self.bottleneck_downsample = self.param(
          "bottleneck_downsample",
          nn.initializers.xavier_uniform(),
          (np.prod(self.grid_size), self.code_len))

    norm_init = nn.initializers.normal(stddev=1.0 / np.sqrt(self.dict_size))
    self.dictionary = self.variable(
        "state", "dictionary",
        lambda shape: norm_init(self.make_rng("state"), shape),
        (self.dict_size, self.codeword_dim or self.width))
    self.counts = self.variable("state", "counts", jnp.ones, (self.dict_size,))

    if not self.bottleneck_resize:
      self.bottleneck_upsample = self.param(
          "bottleneck_upsample",
          nn.initializers.xavier_uniform(),
          (self.code_len, np.prod(self.grid_size)))

    self.pos_embedding_decoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_decoder")
    self.decoder = vit.Encoder(
        depth=self.dec_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        name="decoder")

    self.encoder_head = nn.Dense(self.codeword_dim or self.width)
    self.decoder_stem = nn.Dense(self.width)
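  # Note: `bottleneck_downsample` / `bottleneck_upsample` are learned linear
  # maps between the prod(grid_size)-token encoder sequence and the
  # code_len-token bottleneck; with `bottleneck_resize=True` they are replaced
  # by bilinear resizing inside `encode` / `decode`.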

  def get_codewords(self):
    # The dictionary stores per-codeword running sums of assigned vectors and
    # `counts` the corresponding running counts; their ratio is the (EMA)
    # mean, which is l2-normalized to obtain the actual codewords.
    e = self.dictionary.value / self.counts.value[:, None]
    e = e / jnp.linalg.norm(e, axis=-1, keepdims=True)
    return e

  def encode(self, x, *, ctx=None, train=False, update_dict=True):
    out = {}

    out["stem"] = {}
    for key, embed in self.embeddings.items():
      out["stem"][key] = embed(x[key])
    x = sum(out["stem"].values())

    if self.with_encoder_ctx:
      ctx_tokens = self.stem_conv_ctx_enc(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    x, _ = self.encoder(x + self.pos_embedding_encoder, deterministic=not train)

    if self.bottleneck_resize:
      x = einops.rearrange(x, "b (h w) c -> b h w c",
                           h=self.grid_size[0], w=self.grid_size[1])
      l = int(np.round(self.code_len ** 0.5))
      x = jax.image.resize(
          x, (x.shape[0], l, l, x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("btc,tn->bnc", x, self.bottleneck_downsample)

    x = self.encoder_head(x)

    x = jax.nn.standardize(x, axis=-1)
    x_pre_q = out["bottleneck"] = x
    e = self.get_codewords()
    x, idx = quantize(x, e)
    out["bottleneck_q"] = x
    out["code"] = idx

    # EMA update of the dictionary statistics and codebook-usage metrics,
    # computed only at training time.
    if train:
      # Per-codeword usage counts in the current batch.
      counts = jnp.zeros(self.dict_size, dtype=jnp.int32)
      counts = counts.at[idx].add(1)

      # Sum of the (pre-quantization) vectors assigned to each codeword, used
      # for the moving-average dictionary update below.
      x_sum = jnp.zeros_like(self.dictionary.value)
      x_sum = x_sum.at[idx].add(jax.lax.stop_gradient(x_pre_q))

      if self.statistics_axis_name:
        counts = jax.lax.psum(counts, axis_name=self.statistics_axis_name)
        x_sum = jax.lax.psum(x_sum, axis_name=self.statistics_axis_name)

      out["codebook_max_ratio"] = jnp.max(counts) / jnp.sum(counts)
      out["codebook_zeros_ratio"] = jnp.sum(counts == 0) / len(counts)

      if update_dict:
        self.counts.value = self.counts.value * self.dict_momentum + counts
        self.dictionary.value = (self.dictionary.value * self.dict_momentum +
                                 x_sum)

        # Resample codewords whose discounted count fell below `min_count` by
        # splitting the most frequent ones.
        state = {"dictionary": self.dictionary.value,
                 "counts": self.counts.value,
                 "rng": self.make_rng("vqvae")}
        new_state = jax.lax.while_loop(
            lambda state: jnp.any(state["counts"] < self.min_count),
            split_the_most_frequent_embedding,
            state)
        self.counts.value = new_state["counts"]
        self.dictionary.value = new_state["dictionary"]

    if not self.quantize:
      x = x_pre_q
      out["bottleneck_q"] = x
    return x, out
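  # With default settings `x` and `out["bottleneck_q"]` have shape
  # (batch, code_len, codeword_dim or width), and `out["code"]` has shape
  # (batch, code_len) with integer codes in [0, dict_size).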

  def decode(self, x, ctx=None, discrete_input=False, train=False):
    out = {}

    if discrete_input:
      # `x` holds integer codes; look up the corresponding codewords.
      e = self.get_codewords()
      x = e[x]

    if self.zero_decoder_seq:
      x = jnp.zeros_like(x)

    if train and self.code_dropout != "none":
      # Keep a random-length prefix of the code sequence; with
      # code_dropout == "random" the kept positions are shuffled instead of
      # forming a prefix.
      importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
      thr = jax.random.uniform(self.make_rng("dropout"), x.shape[:1])
      mask = importance[None, :] > thr[:, None]
      if self.code_dropout == "random":
        mask = jax.random.permutation(
            self.make_rng("dropout"), mask, axis=-1, independent=True)
      x = x * mask[:, :, None]

    x = self.decoder_stem(x)

    if self.bottleneck_resize:
      l = int(np.round(self.code_len ** 0.5))
      x = einops.rearrange(x, "b (h w) c -> b h w c", h=l, w=l)
      x = jax.image.resize(
          x, (x.shape[0], self.grid_size[0], self.grid_size[1], x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      x = jnp.einsum("bnc,nt->btc", x, self.bottleneck_upsample)

    if self.with_decoder_ctx:
      ctx_tokens = self.stem_conv_ctx_dec(ctx)
      ctx_tokens = einops.rearrange(ctx_tokens, "b h w c -> b (h w) c")
      x = x + ctx_tokens

    x, _ = self.decoder(x + self.pos_embedding_decoder)

    out["logits"] = {}
    for key, head in self.heads.items():
      out["logits"][key] = head(x)

    return out["logits"], out

  def __call__(self, x, *, ctx=None, train=False, update_dict=True):
    x, out_enc = self.encode(x, ctx=ctx, train=train, update_dict=update_dict)
    x, out_dec = self.decode(x, ctx=ctx, train=train)
    return x, {**out_enc, **out_dec}
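

# A minimal usage sketch (illustrative, not part of the module): the
# dict-valued `inputs`/`outputs` specs give per-token feature shapes, and the
# token count must equal prod(input_size) / prod(patch_size).
#
#   model = Model(inputs={"image": (8, 8, 3)}, outputs={"image": (8, 8, 3)},
#                 input_size=(256, 256), patch_size=(8, 8))
#   batch = {"image": jnp.zeros((2, 32 * 32, 8, 8, 3))}
#   variables = model.init(
#       {"params": jax.random.PRNGKey(0), "state": jax.random.PRNGKey(1)},
#       batch, train=False)
#   logits, aux = model.apply(variables, batch, train=False)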


def load(init_params, init_file, model_params=None, dont_load=()):
  """Loads params from init checkpoint and merges them into init_params."""
  del model_params
  ckpt = flax.core.unfreeze(utils.load_checkpoint(None, init_file))
  params = {"params": ckpt["params"], "state": ckpt["state"]}
  params = flax.training.checkpoints.convert_pre_linen(params)

  # Backwards compatibility: rename capitalized module names ("Encoder",
  # "Decoder") to the lower-case names used by the current model definition.
  if "Encoder" in params["params"]:
    p = params["params"]
    p["encoder"] = p.pop("Encoder")
    p["decoder"] = p.pop("Decoder")
    params["params"] = p
  if init_params is not None:
    params = common.merge_params(params, init_params, dont_load)
  return params["params"], params["state"]