"""Abstract VAE model class.
|
|
|
|
Gaussian encoder and decoder (the latter assumed to have constant variance).
|
|
|
|
Inspiration drawn from https://github.com/pytorch/examples/tree/main/vae.
|
|
"""

import abc
from typing import Optional, Mapping

import flax.linen as nn
import jax
import jax.numpy as jnp


class Model(nn.Module, metaclass=abc.ABCMeta):
  """Abstract VAE model class."""

  codeword_dim: Optional[int] = None
  code_len: int = 256
  code_dropout: str = "none"

  @abc.abstractmethod
  def encode(
      self,
      x: jax.Array,
      *,
      train: bool = False,
  ) -> tuple[jax.Array, jax.Array]:
    """Encodes x into the mean and log-variance of q(z|x)."""
    ...

  def reparametrize(
      self,
      mu: jax.Array,
      logvar: jax.Array,
      rng: jax.Array | None = None,
  ) -> jax.Array:
    """Samples mu + std * eps with eps ~ N(0, I) (reparametrization trick)."""
    std = jnp.exp(0.5 * logvar)
    if rng is None:
      rng = self.make_rng("dropout")
    eps = jax.random.normal(rng, shape=std.shape, dtype=std.dtype)
    return mu + std * eps

  @abc.abstractmethod
  def decode(
      self, x: jax.Array,
      train: bool = False,
  ) -> jax.Array | Mapping[str, jax.Array]:
    """Decodes latent code x into data space (or a mapping of outputs)."""
    ...

  def code_dropout_fn(self, z: jax.Array, *, train: bool = False) -> jax.Array:
    """Applies sequential or random code dropout to z during training."""
    assert self.code_dropout in ["none", "seq", "random"]
    if train and self.code_dropout != "none":
      # Per-position keep score, decaying linearly from ~1 to ~0 over the code
      # length (the exact endpoints 1.0 and 0.0 are excluded).
      importance = jnp.linspace(1.0, 0.0, self.code_len + 2)[1:-1]
      # One uniform threshold per example; positions whose score exceeds it are
      # kept, so "seq" dropout keeps a random-length prefix of the code.
      thr = jax.random.uniform(self.make_rng("dropout"), z.shape[:1])
      mask = importance[None, :] > thr[:, None]
      if self.code_dropout == "random":
        # Shuffle the mask per example so a random subset (not a prefix) is kept.
        mask = jax.random.permutation(
            self.make_rng("dropout"), mask, axis=-1, independent=True)
      z = z * mask[:, :, None]
    return z

  def __call__(
      self,
      x: jax.Array,
      *,
      train: bool = False,
  ) -> tuple[jax.Array | Mapping[str, jax.Array], Mapping[str, jax.Array]]:
    mu, logvar = self.encode(x, train=train)

    # Sample via the reparametrization trick when training; use the mean at
    # eval time.
    if train:
      z = self.reparametrize(mu, logvar)
    else:
      z = mu
    z = self.code_dropout_fn(z, train=train)
    x = self.decode(z, train=train)
    return x, {"mu": mu, "logvar": logvar, "z": z}
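

# -----------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): a minimal
# concrete subclass and usage example showing one way the abstract `Model`
# interface above could be implemented and called. The class name `_ToyVAE`,
# the Dense layer sizes, and the input shapes below are assumptions made for
# this sketch.
# -----------------------------------------------------------------------------
class _ToyVAE(Model):
  """Tiny MLP VAE used only to demonstrate the `Model` interface."""

  def setup(self):
    d = self.codeword_dim or 4
    self.enc = nn.Dense(2 * self.code_len * d)  # Predicts (mu, logvar) jointly.
    self.dec = nn.Dense(16)  # Maps the flattened code back to data space.

  def encode(self, x, *, train=False):
    d = self.codeword_dim or 4
    out = self.enc(x).reshape(x.shape[0], self.code_len, 2 * d)
    mu, logvar = jnp.split(out, 2, axis=-1)  # Each of shape (B, code_len, d).
    return mu, logvar

  def decode(self, x, train=False):
    return self.dec(x.reshape(x.shape[0], -1))


if __name__ == "__main__":
  model = _ToyVAE(code_len=8, codeword_dim=4)
  x = jnp.ones((2, 16))
  # train=False at init time, so no "dropout" rng stream is required yet.
  params = model.init(jax.random.PRNGKey(0), x)
  # Training-mode call: the reparametrization (and code dropout, if enabled)
  # draws from the "dropout" rng stream.
  x_rec, stats = model.apply(
      params, x, train=True, rngs={"dropout": jax.random.PRNGKey(1)})
  # A typical objective for a Gaussian encoder and constant-variance Gaussian
  # decoder: MSE reconstruction plus the closed-form KL to a standard normal.
  rec_loss = jnp.mean((x_rec - x) ** 2)
  kl = -0.5 * jnp.mean(
      1.0 + stats["logvar"] - stats["mu"] ** 2 - jnp.exp(stats["logvar"]))
  print(rec_loss, kl)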