|
|
|
from functools import partial, cached_property |
|
|
|
import jax |
|
from diffusers import FlaxAutoencoderKL |
|
from einops import rearrange |
|
from flax import struct |
|
|
|
from jaxtyping import Array, PyTree, Key, Float, Shaped, Int, UInt8, jaxtyped |
|
from typeguard import typechecked |
|
from functools import partial |
|
# Decorator that enables jaxtyping's shape/dtype annotation checks, using
# typeguard's `typechecked` as the underlying runtime type checker.
typecheck = partial(jaxtyped, typechecker=typechecked)
|
|
|
import jax.numpy as jnp |
|
|
|
@struct.dataclass
class StableVAE:
    """Pretrained Flax `FlaxAutoencoderKL` bundled with its parameters.

    Provides jitted `encode` / `decode` helpers that work on channels-last
    images and optionally rescale the latents per channel.

    Fields:
        params: Parameter pytree passed to `module.apply` under the
            ``"params"`` collection.
        module: The `FlaxAutoencoderKL` module definition. Declared with
            ``pytree_node=False`` so it is treated as static metadata rather
            than as array data when the instance crosses a `jax.jit` boundary.
    """

    params: PyTree[Float[Array, "..."]]
    module: FlaxAutoencoderKL = struct.field(pytree_node=False)

    # Per-channel divisor applied to latents in `encode` and undone in
    # `decode`. Kept as a plain (unannotated) tuple so it is a class constant,
    # not a dataclass field. Presumably these are empirical per-channel
    # standard deviations of the latents -- TODO confirm the source.
    #
    # The original code also defined a per-channel mean,
    #   (1.1743683, -0.4075004, 0.4488433, 0.6760574),
    # but never applied it; latents are only rescaled, not centred.
    _LATENT_STD = (4.9045634, 5.4250283, 3.9848266, 4.010667)

    @classmethod
    def create(cls) -> "StableVAE":
        """Build a `StableVAE` from the ``pcuenq/sd-vae-ft-mse-flax`` checkpoint.

        Returns:
            A new instance holding the loaded module and its parameters.
        """
        module, params = FlaxAutoencoderKL.from_pretrained(
            "pcuenq/sd-vae-ft-mse-flax"
        )
        # Pull the parameters back to host memory before storing them.
        params = jax.device_get(params)
        return cls(params=params, module=module)

    @partial(jax.jit, static_argnames="scale")
    def encode(
        self, key: Key[Array, ""], images: Float[Array, "b h w 3"], scale: bool = True
    ) -> Float[Array, "b lh lw 4"]:
        """Encode channels-last images into sampled latents.

        Args:
            key: PRNG key used to sample from the encoder's latent
                distribution.
            images: Batch of images laid out as (batch, height, width, 3).
            scale: If True, divide the latents by the per-channel
                `_LATENT_STD`. Static under `jax.jit`, so each value of
                `scale` compiles separately.

        Returns:
            Latents sampled from the encoder's latent distribution,
            optionally rescaled.
        """
        # The encoder is fed channels-first input.
        images = rearrange(images, "b h w c -> b c h w")
        posterior = self.module.apply(
            {"params": self.params}, images, method=self.module.encode
        ).latent_dist
        latents = posterior.sample(key)
        if scale:
            # Broadcasts over the trailing (channel) axis of the latents.
            latents = latents / jnp.array(self._LATENT_STD)
        return latents

    @partial(jax.jit, static_argnames="scale")
    def decode(
        self, latents: Float[Array, "b lh lw 4"], scale: bool = True
    ) -> Float[Array, "b h w 3"]:
        """Decode latents back into channels-last images.

        Args:
            latents: Batch of latents laid out as (batch, lh, lw, 4).
            scale: If True, multiply the latents by the per-channel
                `_LATENT_STD` first, undoing the rescaling done by
                `encode(..., scale=True)`. Static under `jax.jit`.

        Returns:
            Decoded images laid out as (batch, height, width, 3).
        """
        if scale:
            # Inverse of the division performed in `encode`.
            latents = latents * jnp.array(self._LATENT_STD)
        images = self.module.apply(
            {"params": self.params}, latents, method=self.module.decode
        ).sample
        # The decoder output is channels-first; return channels-last.
        return rearrange(images, "b c h w -> b h w c")

    @cached_property
    def downscale_factor(self) -> int:
        """Return 2 ** (number of entries in `module.block_out_channels` - 1).

        Presumably the spatial ratio between image and latent resolution
        (one 2x downsample between consecutive blocks) -- TODO confirm
        against the module definition.
        """
        return 2 ** (len(self.module.block_out_channels) - 1)
|
|