# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers

import math
from functools import partial
from typing import Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .modeling_flax_utils import FlaxModelMixin


@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` of the parameters.
    """

    sample: jnp.ndarray


@flax.struct.dataclass
class FlaxAutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`FlaxDiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
            `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "FlaxDiagonalGaussianDistribution"


class FlaxUpsample2D(nn.Module):
    """
    Flax implementation of 2D Upsample layer.

    Args:
        in_channels (`int`):
            Input channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states
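
# Usage sketch for FlaxUpsample2D (illustrative, hypothetical sizes): a nearest-neighbor
# 2x resize followed by a 3x3 convolution that keeps the channel count. Shapes assume
# channels-last (NHWC) inputs, as used throughout this module.
#
#     upsample = FlaxUpsample2D(in_channels=64)
#     x = jnp.zeros((1, 16, 16, 64))
#     params = upsample.init(jax.random.PRNGKey(0), x)
#     y = upsample.apply(params, x)  # (1, 32, 32, 64)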


class FlaxDownsample2D(nn.Module):
    """
    Flax implementation of 2D Downsample layer.

    Args:
        in_channels (`int`):
            Input channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states
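
# Usage sketch for FlaxDownsample2D (illustrative, hypothetical sizes): padding only the
# bottom/right edges by one pixel and then applying a stride-2 "VALID" convolution halves
# the spatial resolution, here 16x16 -> 8x8 with channels-last inputs.
#
#     downsample = FlaxDownsample2D(in_channels=64)
#     x = jnp.zeros((1, 16, 16, 64))
#     params = downsample.init(jax.random.PRNGKey(0), x)
#     y = downsample.apply(params, x)  # (1, 8, 8, 64)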


class FlaxResnetBlock2D(nn.Module):
    """
    Flax implementation of 2D Resnet Block.

    Args:
        in_channels (`int`):
            Input channels
        out_channels (`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm.
        use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
            Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int = None
    dropout: float = 0.0
    groups: int = 32
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual


class FlaxAttentionBlock(nn.Module):
    r"""
    Flax convolution-based multi-head attention block for diffusion-based VAE.

    Parameters:
        channels (:obj:`int`):
            Input channels
        num_head_channels (:obj:`int`, *optional*, defaults to `None`):
            Number of channels per attention head; if `None`, a single head is used
        num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    channels: int
    num_head_channels: int = None
    num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

        dense = partial(nn.Dense, self.channels, dtype=self.dtype)

        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
        self.query, self.key, self.value = dense(), dense(), dense()
        self.proj_attn = dense()

    def transpose_for_scores(self, projection):
        new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
        new_projection = projection.reshape(new_projection_shape)
        # (B, T, H, D) -> (B, H, T, D)
        new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
        return new_projection

    def __call__(self, hidden_states):
        residual = hidden_states
        batch, height, width, channels = hidden_states.shape

        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.reshape((batch, height * width, channels))

        query = self.query(hidden_states)
        key = self.key(hidden_states)
        value = self.value(hidden_states)

        # transpose
        query = self.transpose_for_scores(query)
        key = self.transpose_for_scores(key)
        value = self.transpose_for_scores(value)

        # compute attentions
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
        attn_weights = nn.softmax(attn_weights, axis=-1)

        # attend to values
        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)

        hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
        new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
        hidden_states = hidden_states.reshape(new_hidden_states_shape)

        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.reshape((batch, height, width, channels))
        hidden_states = hidden_states + residual
        return hidden_states
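
# Scaling note for FlaxAttentionBlock: `scale` is applied to both query and key before the
# einsum, so the product contributes scale**2 = 1 / sqrt(channels / num_heads), i.e. the
# usual 1/sqrt(d_head) softmax temperature, split across the two operands (which tends to
# be better behaved in low precision than dividing the logits afterwards). A minimal shape
# sketch with hypothetical sizes:
#
#     attn = FlaxAttentionBlock(channels=512, num_head_channels=64)  # 8 heads of 64 dims
#     x = jnp.zeros((1, 32, 32, 512))
#     params = attn.init(jax.random.PRNGKey(0), x)
#     y = attn.apply(params, x)  # (1, 32, 32, 512), residual added inside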


class FlaxDownEncoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Encoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet layers in this block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add a downsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)

        return hidden_states


class FlaxUpDecoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet layers in this block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add an upsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2D(nn.Module):
    r"""
    Flax Unet Mid-Block module.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet layers in this block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet and Attention block group norm
        num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
            Number of attention heads for each attention block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    num_attention_heads: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxAttentionBlock(
                channels=self.in_channels,
                num_head_channels=self.num_attention_heads,
                num_groups=resnet_groups,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        return hidden_states


class FlaxEncoder(nn.Module):
    r"""
    Flax Implementation of VAE Encoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            DownEncoder block type
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of ResNet layers for each block
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups for normalization
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function
        double_z (:obj:`bool`, *optional*, defaults to `False`):
            Whether to double the last output channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    double_z: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels

        # in
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # downsampling
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, _ in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = FlaxDownEncoderBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                resnet_groups=self.norm_num_groups,
                add_downsample=not is_final_block,
                dtype=self.dtype,
            )
            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # end
        conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            conv_out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # in
        sample = self.conv_in(sample)

        # downsampling
        for block in self.down_blocks:
            sample = block(sample, deterministic=deterministic)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # end
        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample
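
# Shape sketch for FlaxEncoder (illustrative, hypothetical configuration): only non-final
# blocks downsample, so with two blocks the spatial size is halved once, and
# `double_z=True` doubles the output channels so the result can later be split into
# mean and log-variance.
#
#     encoder = FlaxEncoder(
#         out_channels=4,
#         down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
#         block_out_channels=(32, 64),
#         double_z=True,
#     )
#     x = jnp.zeros((1, 32, 32, 3))
#     params = encoder.init(jax.random.PRNGKey(0), x)
#     y = encoder.apply(params, x)  # (1, 16, 16, 8)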


class FlaxDecoder(nn.Module):
    r"""
    Flax Implementation of VAE Decoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            UpDecoder block type
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of ResNet layers for each block
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups for normalization
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int = 3
    out_channels: int = 3
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels

        # z to block_in
        self.conv_in = nn.Conv(
            block_out_channels[-1],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # upsampling
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]

        up_blocks = []
        for i, _ in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = FlaxUpDecoderBlock2D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block + 1,
                resnet_groups=self.norm_num_groups,
                add_upsample=not is_final_block,
                dtype=self.dtype,
            )
            up_blocks.append(up_block)
            prev_output_channel = output_channel

        self.up_blocks = up_blocks

        # end
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # z to block_in
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # upsampling
        for block in self.up_blocks:
            sample = block(sample, deterministic=deterministic)

        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample


class FlaxDiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        # Last axis to account for channels-last
        self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
        self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = jnp.zeros_like(self.mean)

    def sample(self, key):
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    def kl(self, other=None):
        if self.deterministic:
            return jnp.array([0.0])

        if other is None:
            return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])

        return 0.5 * jnp.sum(
            jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            axis=[1, 2, 3],
        )

    def nll(self, sample, axis=[1, 2, 3]):
        if self.deterministic:
            return jnp.array([0.0])

        logtwopi = jnp.log(2.0 * jnp.pi)
        return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)

    def mode(self):
        return self.mean
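
# Sketch of how the distribution is typically consumed (hypothetical sizes): the encoder
# output is split channel-wise into mean and log-variance, `sample` draws a latent via the
# reparameterization trick (mean + std * eps), and `kl()` gives the closed-form KL
# divergence to a standard normal prior used in the VAE training loss.
#
#     moments = jnp.zeros((1, 32, 32, 8))  # 2 * latent_channels on the last axis
#     posterior = FlaxDiagonalGaussianDistribution(moments)
#     z = posterior.sample(jax.random.PRNGKey(0))  # (1, 32, 32, 4)
#     kl_loss = posterior.kl()  # per-example KL, shape (1,)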


@flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Flax implementation of a VAE model with KL loss for decoding latent representations.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3):
            Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`):
            Number of ResNet layers for each block.
        act_fn (`str`, *optional*, defaults to `silu`):
            The activation function to use.
        latent_channels (`int`, *optional*, defaults to `4`):
            Number of channels in the latent space.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups for normalization.
        sample_size (`int`, *optional*, defaults to 32):
            Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` of the parameters.
    """

    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 1
    act_fn: str = "silu"
    latent_channels: int = 4
    norm_num_groups: int = 32
    sample_size: int = 32
    scaling_factor: float = 0.18215
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.encoder = FlaxEncoder(
            in_channels=self.config.in_channels,
            out_channels=self.config.latent_channels,
            down_block_types=self.config.down_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            act_fn=self.config.act_fn,
            norm_num_groups=self.config.norm_num_groups,
            double_z=True,
            dtype=self.dtype,
        )
        self.decoder = FlaxDecoder(
            in_channels=self.config.latent_channels,
            out_channels=self.config.out_channels,
            up_block_types=self.config.up_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            norm_num_groups=self.config.norm_num_groups,
            act_fn=self.config.act_fn,
            dtype=self.dtype,
        )
        self.quant_conv = nn.Conv(
            2 * self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        self.post_quant_conv = nn.Conv(
            self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)

        params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
        rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}

        return self.init(rngs, sample)["params"]

    def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
        sample = jnp.transpose(sample, (0, 2, 3, 1))

        hidden_states = self.encoder(sample, deterministic=deterministic)
        moments = self.quant_conv(hidden_states)
        posterior = FlaxDiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return FlaxAutoencoderKLOutput(latent_dist=posterior)

    def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
        if latents.shape[-1] != self.config.latent_channels:
            latents = jnp.transpose(latents, (0, 2, 3, 1))

        hidden_states = self.post_quant_conv(latents)
        hidden_states = self.decoder(hidden_states, deterministic=deterministic)

        hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))

        if not return_dict:
            return (hidden_states,)

        return FlaxDecoderOutput(sample=hidden_states)

    def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
        posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
        if sample_posterior:
            rng = self.make_rng("gaussian")
            hidden_states = posterior.latent_dist.sample(rng)
        else:
            hidden_states = posterior.latent_dist.mode()

        sample = self.decode(hidden_states, return_dict=return_dict).sample

        if not return_dict:
            return (sample,)

        return FlaxDecoderOutput(sample=sample)
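
# End-to-end usage sketch (illustrative; weights would normally come from a pretrained
# checkpoint rather than random initialization): images are passed in NCHW layout,
# encoded latents are scaled by `scaling_factor` before being handed to a diffusion
# model, and unscaled again before decoding.
#
#     vae = FlaxAutoencoderKL(sample_size=64)
#     params = vae.init_weights(jax.random.PRNGKey(0))
#     images = jnp.zeros((1, 3, 64, 64))
#     posterior = vae.apply({"params": params}, images, method=vae.encode).latent_dist
#     latents = posterior.mode() * vae.scaling_factor
#     decoded = vae.apply({"params": params}, latents / vae.scaling_factor, method=vae.decode).sample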