Spaces:
Runtime error
Runtime error
import logging | |
import flax.linen as nn | |
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class Decoder(nn.Module):
    """Convert a latent code into a transformer encoding.

    Applies a learned dense projection to ``dim_model`` features, followed by
    layer normalization.

    Attributes:
        dim_model: Output feature size of the dense projection.
        n_latent_tokens: Number of latent tokens.
            NOTE(review): not read by ``__call__``. Kept as a field so
            existing constructor calls keep working.
    """

    dim_model: int
    n_latent_tokens: int

    # `@nn.compact` is required. Flax only permits creating submodules inline
    # (nn.Dense / nn.LayerNorm below) inside a method wrapped by it, or in
    # `setup()`. Without it, init/apply raises AssignSubModuleError.
    @nn.compact
    def __call__(self, latent_code):
        """Project ``latent_code`` to ``dim_model`` features and normalize.

        Args:
            latent_code: Input array. Per the original annotation it is shaped
                (batch, latent_tokens_per_sequence, latent_token_dim).
                TODO confirm with callers.

        Returns:
            The layer-normalized projection. Per the original annotation it is
            shaped (batch, latent_tokens_per_sequence, dim_model).
        """
        raw_latent_tokens = nn.Dense(self.dim_model)(latent_code)
        latent_tokens = nn.LayerNorm()(raw_latent_tokens)
        return latent_tokens  # (batch, latent_tokens_per_sequence, dim_model)
# Registry of VAE decoder classes, looked up by name; the empty-string key
# selects `Decoder`.
VAE_DECODER_MODELS = {'': Decoder}