Spaces:
Runtime error
Runtime error
import jax.numpy as jnp | |
import flax.linen as nn | |
from t5_vae_flax_alt.src.encoders import VAE_ENCODER_MODELS | |
from t5_vae_flax_alt.src.decoders import VAE_DECODER_MODELS | |
from t5_vae_flax_alt.src.config import T5VaeConfig | |
class VAE(nn.Module):
    # see https://github.com/google/flax#what-does-flax-look-like
    """
    An MMD-VAE used with encoder-decoder models.

    Encodes all token encodings into a single latent and decodes them back out.

    Attributes:
        config: Model configuration. ``vae_encoder_model`` / ``vae_decoder_model``
            select the submodule classes from ``VAE_ENCODER_MODELS`` /
            ``VAE_DECODER_MODELS``. ``latent_token_size``, ``n_latent_tokens``
            and ``t5.d_model`` are passed to their constructors.
        dtype: The dtype of the computation. NOTE(review): this field is not
            currently forwarded to the encoder/decoder submodules.
    """
    config: T5VaeConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Build the configured encoder/decoder by looking their classes up by name.
        # Encoder: (latent_token_size, n_latent_tokens); decoder: (t5 d_model, n_latent_tokens).
        self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](
            self.config.latent_token_size, self.config.n_latent_tokens
        )
        self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](
            self.config.t5.d_model, self.config.n_latent_tokens
        )

    def __call__(self, encoding=None, latent_codes=None):
        """Encode (if needed) and decode.

        Args:
            encoding: Input to the encoder submodule (presumably T5 encoder
                hidden states — confirm against callers). Used only when
                ``latent_codes`` is not supplied.
            latent_codes: Optional precomputed latent codes. When given, the
                encoder is skipped and these codes are decoded directly.

        Returns:
            A ``(decoded, latent_codes)`` tuple, where ``latent_codes`` are the
            codes that were actually decoded.

        Raises:
            ValueError: If neither ``encoding`` nor ``latent_codes`` is given.
        """
        # Previously ``latent_codes`` was always overwritten, making the
        # argument dead; only run the encoder when codes were not provided.
        if latent_codes is None:
            if encoding is None:
                raise ValueError("VAE requires either `encoding` or `latent_codes`.")
            latent_codes = self.encode(encoding)
        return self.decode(latent_codes), latent_codes

    def encode(self, encoding):
        """Map ``encoding`` to latent codes using the configured encoder."""
        return self.encoder(encoding)

    def decode(self, latent):
        """Map latent codes back out using the configured decoder."""
        return self.decoder(latent)