"""Simple VAE fork of the UViM VQ-VAE (proj/uvim/vit.py) with small changes."""

from typing import Optional, Sequence, Mapping, Any

from big_vision import utils
from big_vision.models import common
from big_vision.models import vit
from big_vision.models.proj.givt import vae

import einops
import flax.linen as nn
import flax.training.checkpoints
import jax
import jax.numpy as jnp
import numpy as np


class Model(vae.Model):
  """ViT-based VAE model."""

  input_size: Sequence[int] = (256, 256)
  patch_size: Sequence[int] = (16, 16)
  width: int = 768
  enc_depth: int = 6
  dec_depth: int = 6
  mlp_dim: Optional[int] = None  # Defaults to 4x width inside vit.Encoder.
  num_heads: int = 12
  posemb: str = "learn"  # Or "sincos2d".
  dropout: float = 0.0
  head_zeroinit: bool = True
  bottleneck_resize: bool = False
  # Maps modality name to (input channel index, number of classes).
  inout_specs: Optional[Mapping[str, tuple[int, int]]] = None
  scan: bool = False
  remat_policy: str = "nothing_saveable"

  def setup(self) -> None:
    self.grid_size = np.array(self.input_size) // np.array(self.patch_size)

    self.embedding = nn.Conv(
        self.width, self.patch_size, strides=self.patch_size,
        padding="VALID", name="embedding")

    self.pos_embedding_encoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_encoder")
    self.encoder = vit.Encoder(
        depth=self.enc_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        scan=self.scan,
        remat_policy=self.remat_policy,
        name="encoder")

    if not self.bottleneck_resize:
      # Learned linear resampling between the encoder sequence length and the
      # latent sequence length (code_len).
      self.bottleneck_downsample = self.param(
          "bottleneck_downsample",
          nn.initializers.xavier_uniform(),
          (np.prod(self.grid_size), self.code_len))
      self.bottleneck_upsample = self.param(
          "bottleneck_upsample",
          nn.initializers.xavier_uniform(),
          (self.code_len, np.prod(self.grid_size)))

    self.pos_embedding_decoder = vit.get_posemb(
        self, self.posemb, self.grid_size, self.width, "pos_embedding_decoder")
    self.decoder = vit.Encoder(
        depth=self.dec_depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.dropout,
        scan=self.scan,
        remat_policy=self.remat_policy,
        name="decoder")

    # The encoder head outputs mean and log-variance, hence the factor of 2.
    # Using (codeword_dim or width) avoids multiplying None when codeword_dim
    # is unset.
    self.encoder_head = nn.Dense((self.codeword_dim or self.width) * 2)
    self.decoder_stem = nn.Dense(self.width)

    kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}

    if self.inout_specs is not None:
      num_out_channels = sum(
          num_classes for _, num_classes in self.inout_specs.values())
    else:
      num_out_channels = 3

    self.head = nn.Dense(
        num_out_channels * np.prod(self.patch_size),
        name="decoder_head", **kw)

  def encode(
      self,
      x: jax.Array,
      *,
      train: bool = False,
  ) -> tuple[jax.Array, jax.Array]:
    if self.inout_specs is not None:
      # Replace the selected input channels by their one-hot encodings.
      one_hot_inputs = []
      for in_ch, num_classes in self.inout_specs.values():
        one_hot_inputs.append(nn.one_hot(x[..., in_ch], num_classes))
      x = jnp.concatenate(one_hot_inputs, axis=-1)

    # Patchify and flatten into a token sequence.
    x = self.embedding(x)
    x = einops.rearrange(x, "b h w c -> b (h w) c")

    x, _ = self.encoder(x + self.pos_embedding_encoder, deterministic=not train)

    if self.bottleneck_resize:
      # Spatially resize the token grid so that it contains code_len tokens.
      x = einops.rearrange(x, "b (h w) c -> b h w c",
                           h=self.grid_size[0], w=self.grid_size[1])
      l = int(np.round(self.code_len ** 0.5))
      x = jax.image.resize(
          x, (x.shape[0], l, l, x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      # Learned downsampling along the token dimension.
      x = jnp.einsum("btc,tn->bnc", x, self.bottleneck_downsample)

    x = self.encoder_head(x)

    mu, logvar = jnp.split(x, 2, axis=-1)
    return mu, logvar

  def decode(
      self,
      x: jax.Array,
      train: bool = False,
  ) -> jax.Array | Mapping[str, jax.Array]:
    x = self.decoder_stem(x)

    if self.bottleneck_resize:
      # Resize the latent token grid back to the full encoder grid size.
      l = int(np.round(self.code_len ** 0.5))
      x = einops.rearrange(x, "b (h w) c -> b h w c", h=l, w=l)
      x = jax.image.resize(
          x, (x.shape[0], self.grid_size[0], self.grid_size[1], x.shape[3]),
          method="linear")
      x = einops.rearrange(x, "b h w c -> b (h w) c")
    else:
      # Learned upsampling along the token dimension.
      x = jnp.einsum("bnc,nt->btc", x, self.bottleneck_upsample)

    x, _ = self.decoder(x + self.pos_embedding_decoder, deterministic=not train)
    x = self.head(x)

    # Un-patchify back to image space.
    x = einops.rearrange(x, "b (h w) (p q c) -> b (h p) (w q) c",
                         h=self.grid_size[0], w=self.grid_size[1],
                         p=self.patch_size[0], q=self.patch_size[1])

    if self.inout_specs is None:
      x = jnp.clip(x, -1.0, 1.0)
    else:
      # Split the output channels into per-modality logits.
      x_dict = {}
      channel_index = 0
      for name, (_, num_channels) in self.inout_specs.items():
        x_dict[name] = x[..., channel_index : channel_index + num_channels]
        channel_index += num_channels
      x = x_dict

    return x


def load(
    init_params: Any,
    init_file: str,
    model_params: Any = None,
    dont_load: Sequence[str] = (),
) -> Any:
  """Loads params from the init checkpoint and merges them into init_params."""
  del model_params
  params = flax.core.unfreeze(utils.load_params(init_file))
  if init_params is not None:
    params = common.merge_params(params, init_params, dont_load)
  return params
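

# Minimal usage sketch (illustrative, not part of the original module). It only
# relies on the `encode`/`decode` methods defined above; `code_len` and
# `codeword_dim` are attributes inherited from `vae.Model`, and the helper
# `full_pass` as well as the concrete values and checkpoint path below are
# assumptions for the example.
#
#   def full_pass(mdl, x, z_rng):
#     mu, logvar = mdl.encode(x)
#     # Reparameterization trick: sample a latent, then reconstruct.
#     z = mu + jnp.exp(0.5 * logvar) * jax.random.normal(z_rng, mu.shape)
#     return mdl.decode(z)
#
#   model = Model(input_size=(256, 256), patch_size=(16, 16),
#                 code_len=256, codeword_dim=16)
#   x = jnp.zeros((1, 256, 256, 3))
#   rng, z_rng = jax.random.split(jax.random.PRNGKey(0))
#   params = model.init(rng, x, z_rng, method=full_pass)
#   x_rec = model.apply(params, x, z_rng, method=full_pass)
#
#   # Merging pretrained weights into the freshly initialized params
#   # (hypothetical checkpoint path):
#   restored = load(params["params"], "gs://path/to/vae_checkpoint.npz")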