# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers

import math
from functools import partial
from typing import Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput


@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Decoded output sample of the model. Output of the last layer of the model.
    """

    sample: jnp.ndarray


@flax.struct.dataclass
class FlaxAutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`FlaxDiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
            `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "FlaxDiagonalGaussianDistribution"


class FlaxUpsample2D(nn.Module):
    """
    Flax implementation of 2D Upsample layer

    Doubles the height and width of a channels-last input with nearest-neighbour resizing and then applies a
    3x3 convolution.

    Args:
        in_channels (`int`):
            Input channels (also used as the number of output features of the convolution)
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # Input is unpacked as channels-last: (batch, height, width, channels).
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxDownsample2D(nn.Module):
    """
    Flax implementation of 2D Downsample layer

    Pads the bottom/right edge of the spatial dimensions by one element and applies a stride-2 3x3 convolution
    with `VALID` padding.

    Args:
        in_channels (`int`):
            Input channels (also used as the number of output features of the convolution)
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxResnetBlock2D(nn.Module):
    """
    Flax implementation of 2D Resnet Block.

    Applies (GroupNorm -> swish -> 3x3 conv) twice, with dropout before the second convolution, and adds the
    (optionally 1x1-projected) input as a residual.

    Args:
        in_channels (`int`):
            Input channels
        out_channels (`int`, *optional*, defaults to `None`):
            Output channels. When `None`, `in_channels` is used.
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm.
        use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
            Whether to use `nin_shortcut`. This activates a new layer inside ResNet block. When `None`, the
            shortcut is enabled exactly when the input and output channel counts differ.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int = None
    dropout: float = 0.0
    groups: int = 32
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        # The 1x1 projection is only created when needed so that the residual can be added to the output.
        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual


class FlaxAttentionBlock(nn.Module):
    r"""
    Flax Convolutional based multi-head attention block for diffusion-based VAE.

    Parameters:
        channels (:obj:`int`):
            Input channels
        num_head_channels (:obj:`int`, *optional*, defaults to `None`):
            Number of channels per attention head. The number of heads is `channels // num_head_channels`; when
            `None`, a single head is used.
        num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """
    channels: int
    num_head_channels: int = None
    num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

        dense = partial(nn.Dense, self.channels, dtype=self.dtype)

        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
        self.query, self.key, self.value = dense(), dense(), dense()
        self.proj_attn = dense()

    def transpose_for_scores(self, projection):
        """Split the last axis of `projection` into heads and move the head axis in front of the token axis."""
        new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
        # split the channel axis into heads: (B, T, H * D) -> (B, T, H, D)
        new_projection = projection.reshape(new_projection_shape)
        # move heads to 2nd position: (B, T, H, D) -> (B, H, T, D)
        new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
        return new_projection

    def __call__(self, hidden_states):
        residual = hidden_states
        batch, height, width, channels = hidden_states.shape

        hidden_states = self.group_norm(hidden_states)

        # Flatten the spatial grid into a sequence of height * width tokens.
        hidden_states = hidden_states.reshape((batch, height * width, channels))

        query = self.query(hidden_states)
        key = self.key(hidden_states)
        value = self.value(hidden_states)

        # transpose
        query = self.transpose_for_scores(query)
        key = self.transpose_for_scores(key)
        value = self.transpose_for_scores(value)

        # compute attentions
        # The scale is the fourth root of the per-head width and is applied to both query and key, so their
        # product is scaled by 1 / sqrt(channels / num_heads).
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
        attn_weights = nn.softmax(attn_weights, axis=-1)

        # attend to values
        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, H * D)
        hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
        new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
        hidden_states = hidden_states.reshape(new_hidden_states_shape)

        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.reshape((batch, height, width, channels))
        hidden_states = hidden_states + residual
        return hidden_states


class FlaxDownEncoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Encoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of Resnet layer block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Only the first resnet sees `in_channels`; the following ones map out_channels -> out_channels.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels if i == 0 else self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            for i in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)

        return hidden_states


class FlaxUpDecoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of Resnet layer block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Only the first resnet sees `in_channels`; the following ones map out_channels -> out_channels.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels if i == 0 else self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            for i in range(self.num_layers)
        ]

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2D(nn.Module):
    r"""
    Flax Unet Mid-Block module.

    Consists of one resnet block followed by `num_layers` (attention block, resnet block) pairs, all operating
    on `in_channels` channels.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (attention, resnet) pairs that follow the first resnet block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet and Attention block group norm. When `None`,
            `min(in_channels // 4, 32)` is used.
        attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
            Number of channels per attention head for each attention block (passed to `FlaxAttentionBlock` as
            `num_head_channels`)
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    attn_num_head_channels: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxAttentionBlock(
                channels=self.in_channels,
                num_head_channels=self.attn_num_head_channels,
                num_groups=resnet_groups,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        return hidden_states


class FlaxEncoder(nn.Module):
    r"""
    Flax Implementation of VAE Encoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            DownEncoder block type. Only the number of entries is used here; every block is a
            `FlaxDownEncoderBlock2D`.
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of Resnet layer for each block
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            norm num group
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function. Not read by this module; `nn.swish` is always applied.
        double_z (:obj:`bool`, *optional*, defaults to `False`):
            Whether to double the last output channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    double_z: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels
        # in
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # downsampling
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, _ in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # The last block keeps its spatial resolution (no downsampler).
            is_final_block = i == len(block_out_channels) - 1

            down_block = FlaxDownEncoderBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                resnet_groups=self.norm_num_groups,
                add_downsample=not is_final_block,
                dtype=self.dtype,
            )
            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            attn_num_head_channels=None,
            dtype=self.dtype,
        )

        # end
        conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            conv_out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # in
        sample = self.conv_in(sample)

        # downsampling
        for block in self.down_blocks:
            sample = block(sample, deterministic=deterministic)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # end
        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample


class FlaxDecoder(nn.Module):
    r"""
    Flax Implementation of VAE Decoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            UpDecoder block type. Only the number of entries is used here; every block is a
            `FlaxUpDecoderBlock2D`.
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of Resnet layer for each block (each up block is built with `layers_per_block + 1` resnets)
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            norm num group
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function. Not read by this module; `nn.swish` is always applied.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            parameters `dtype`
    """
    in_channels: int = 3
    out_channels: int = 3
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels

        # z to block_in
        self.conv_in = nn.Conv(
            block_out_channels[-1],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            attn_num_head_channels=None,
            dtype=self.dtype,
        )

        # upsampling
        # The decoder walks the channel list in reverse order (widest block first).
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        up_blocks = []
        for i, _ in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            # The last block keeps its spatial resolution (no upsampler).
            is_final_block = i == len(block_out_channels) - 1

            up_block = FlaxUpDecoderBlock2D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block + 1,
                resnet_groups=self.norm_num_groups,
                add_upsample=not is_final_block,
                dtype=self.dtype,
            )
            up_blocks.append(up_block)

        self.up_blocks = up_blocks

        # end
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # z to block_in
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # upsampling
        for block in self.up_blocks:
            sample = block(sample, deterministic=deterministic)

        # end
        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample


class FlaxDiagonalGaussianDistribution:
    """
    Diagonal Gaussian distribution parameterised by a mean and a log-variance.

    Args:
        parameters (`jnp.ndarray`):
            Array whose last axis is split into two equal halves: the mean followed by the log-variance.
        deterministic (`bool`, *optional*, defaults to `False`):
            When `True`, the standard deviation and variance are set to zero, so `sample` returns the mean and
            `kl` / `nll` return `jnp.array([0.0])`.
    """

    def __init__(self, parameters, deterministic=False):
        # Last axis to account for channels-last
        self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
        # Clip the log-variance so that the exponentials below stay in a bounded range.
        self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = jnp.zeros_like(self.mean)

    def sample(self, key):
        """Draw a sample as `mean + std * eps`, with `eps` drawn from a standard normal using PRNG `key`."""
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    def kl(self, other=None):
        """
        Return the KL divergence from this distribution to `other`, summed over axes 1, 2 and 3.

        When `other` is `None` the divergence is taken against a standard normal distribution.
        """
        if self.deterministic:
            return jnp.array([0.0])

        if other is None:
            return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=(1, 2, 3))

        return 0.5 * jnp.sum(
            jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            axis=(1, 2, 3),
        )

    def nll(self, sample, axis=(1, 2, 3)):
        """Return the negative log-likelihood of `sample` under this distribution, summed over `axis`."""
        if self.deterministic:
            return jnp.array([0.0])

        logtwopi = jnp.log(2.0 * jnp.pi)
        return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean


@flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational
    Bayes by Diederik P. Kingma and Max Welling.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            DownEncoder block type
        up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            UpDecoder block type
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `1`):
            Number of Resnet layer for each block
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function
        latent_channels (:obj:`int`, *optional*, defaults to `4`):
            Latent space channels
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            Norm num group
        sample_size (:obj:`int`, *optional*, defaults to `32`):
            Sample input size
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            parameters `dtype`
    """
    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 1
    act_fn: str = "silu"
    latent_channels: int = 4
    norm_num_groups: int = 32
    sample_size: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The encoder emits 2 * latent_channels channels (double_z=True): mean and log-variance.
        self.encoder = FlaxEncoder(
            in_channels=self.config.in_channels,
            out_channels=self.config.latent_channels,
            down_block_types=self.config.down_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            act_fn=self.config.act_fn,
            norm_num_groups=self.config.norm_num_groups,
            double_z=True,
            dtype=self.dtype,
        )
        self.decoder = FlaxDecoder(
            in_channels=self.config.latent_channels,
            out_channels=self.config.out_channels,
            up_block_types=self.config.up_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            norm_num_groups=self.config.norm_num_groups,
            act_fn=self.config.act_fn,
            dtype=self.dtype,
        )
        self.quant_conv = nn.Conv(
            2 * self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        self.post_quant_conv = nn.Conv(
            self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        """Initialise the module on an all-zero channels-first sample and return its `"params"` collection."""
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)

        params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
        rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}

        return self.init(rngs, sample)["params"]

    def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
        """
        Encode a channels-first `sample` into a `FlaxDiagonalGaussianDistribution` over latents.

        Returns a `FlaxAutoencoderKLOutput`, or a 1-tuple holding the distribution when `return_dict` is `False`.
        """
        # channels-first -> channels-last, which is what the Flax convolutions operate on
        sample = jnp.transpose(sample, (0, 2, 3, 1))

        hidden_states = self.encoder(sample, deterministic=deterministic)
        moments = self.quant_conv(hidden_states)
        posterior = FlaxDiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return FlaxAutoencoderKLOutput(latent_dist=posterior)

    def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
        """
        Decode `latents` into a channels-first sample.

        Returns a `FlaxDecoderOutput`, or a 1-tuple holding the decoded array when `return_dict` is `False`.
        """
        # Latents whose last axis is not `latent_channels` are treated as channels-first and transposed.
        if latents.shape[-1] != self.config.latent_channels:
            latents = jnp.transpose(latents, (0, 2, 3, 1))

        hidden_states = self.post_quant_conv(latents)
        hidden_states = self.decoder(hidden_states, deterministic=deterministic)

        # channels-last -> channels-first
        hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))

        if not return_dict:
            return (hidden_states,)

        return FlaxDecoderOutput(sample=hidden_states)

    def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
        """
        Encode `sample`, take either a random draw or the mode of the posterior, and decode it.

        Args:
            sample: Channels-first input passed to `encode`.
            sample_posterior: When `True`, draw the latents with the `"gaussian"` RNG stream; otherwise use the
                mode of the posterior.
            deterministic: Forwarded to both `encode` and `decode`.
            return_dict: When `False`, return a 1-tuple holding the decoded array instead of a
                `FlaxDecoderOutput`.
        """
        # Always request structured outputs internally; `return_dict` only controls the format of the value
        # returned by this method (tuples have no `.latent_dist` / `.sample` attributes).
        posterior = self.encode(sample, deterministic=deterministic, return_dict=True)
        if sample_posterior:
            rng = self.make_rng("gaussian")
            hidden_states = posterior.latent_dist.sample(rng)
        else:
            hidden_states = posterior.latent_dist.mode()

        sample = self.decode(hidden_states, deterministic=deterministic, return_dict=True).sample

        if not return_dict:
            return (sample,)

        return FlaxDecoderOutput(sample=sample)