File size: 3,400 Bytes
			
			| 51e86ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 | 
from functools import partial, cached_property
import jax
from diffusers import FlaxAutoencoderKL
from einops import rearrange
from flax import struct
from jaxtyping import Array, PyTree, Key, Float, Shaped, Int, UInt8, jaxtyped
from typeguard import typechecked
from functools import partial
typecheck = partial(jaxtyped, typechecker=typechecked)
import jax.numpy as jnp
import pickle
def load_stats(path="stats.pkl"):
    """Load the pickled latent-whitening statistics from *path*.

    The returned object is indexed by latent channel by ``StableVAE`` and each
    entry is unpacked as a ``(mean, zca)`` pair.

    Only load files you trust: ``pickle.load`` can execute arbitrary code.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Per-channel (mean, zca) whitening statistics. The file is optional at import
# time: if it is missing, unreadable, truncated or not a valid pickle, ``stats``
# is bound to None (instead of being left undefined) so this module can still
# be imported. Any other error propagates instead of being silently swallowed.
try:
    stats = load_stats()
except (OSError, EOFError, pickle.UnpicklingError):
    stats = None
@struct.dataclass
class StableVAE:
    """Pytree wrapper around a pretrained Flax Stable Diffusion VAE.

    ``encode`` whitens the VAE latents channel by channel using the module-level
    ``stats`` (indexed by channel; each entry a ``(mean, zca)`` pair applied to
    that channel's flattened spatial positions). ``decode`` inverts that
    transform before running the VAE decoder.
    """

    # VAE weights; a pytree node, so it is traced through ``jax.jit``.
    params: PyTree[Float[Array, "..."]]
    # Flax module definition; marked non-pytree so jit treats it as static.
    module: FlaxAutoencoderKL = struct.field(pytree_node=False)

    @classmethod
    def create(cls) -> "StableVAE":
        """Build a ``StableVAE`` from the pretrained ``pcuenq/sd-vae-ft-mse-flax`` weights.

        The parameters are fetched to host memory with ``jax.device_get``.
        """
        # Alternative checkpoint previously used here:
        # "stabilityai/stable-diffusion-xl-base-1.0" with subfolder="vae".
        module, params = FlaxAutoencoderKL.from_pretrained(
            "pcuenq/sd-vae-ft-mse-flax"
        )
        params = jax.device_get(params)
        return cls(
            params=params,
            module=module,
        )

    @staticmethod
    def _channel_stats():
        """Return the module-level whitening statistics, or raise if they were not loaded."""
        # ``stats`` is a module global that may be undefined (or None) when
        # stats.pkl could not be loaded at import time.
        try:
            channel_stats = stats
        except NameError:
            channel_stats = None
        if channel_stats is None:
            raise RuntimeError(
                "StableVAE latent whitening statistics are unavailable; "
                "stats.pkl could not be loaded at import time."
            )
        return channel_stats

    @partial(jax.jit, static_argnames="scale")
    def encode(
        self, key: Key[Array, ""], images: Float[Array, "b h w 3"], scale: bool = True
    ) -> Float[Array, "b lh lw 4"]:
        """Encode channels-last images into per-channel whitened latents.

        Args:
            key: PRNG key used to sample from the VAE's latent distribution.
            images: Channels-last image batch.
            scale: Currently ignored. The ``scaling_factor`` multiply is
                disabled and the whitening is applied instead. Kept so existing
                call sites and the jit static argument remain valid.

        Returns:
            Whitened latents, channels-last.

        Raises:
            RuntimeError: If the whitening statistics were not loaded.
        """
        channel_stats = self._channel_stats()
        # The VAE encoder is called with channels-first input.
        images = rearrange(images, "b h w c -> b c h w")
        latents = self.module.apply(
            {"params": self.params}, images, method=self.module.encode
        ).latent_dist.sample(key)

        # Latents are treated as channels-last (b, lh, lw, c), matching the
        # return annotation.
        B, H, W, C = latents.shape
        whitened = []
        for c in range(C):
            mean, zca = channel_stats[c]
            # Flatten the spatial grid so each channel is a (B, H*W) matrix.
            x = latents[..., c].reshape(B, -1)
            x_whitened = (zca @ (x - mean).T).T
            whitened.append(x_whitened.reshape(B, H, W))
        return jnp.stack(whitened, axis=-1)

    @partial(jax.jit, static_argnames="scale")
    def decode(
        self, latents: Float[Array, "b lh lw 4"], scale: bool = True
    ) -> Float[Array, "b h w 3"]:
        """Undo the per-channel whitening applied by ``encode`` and decode to images.

        Args:
            latents: Whitened channels-last latents, as returned by ``encode``.
            scale: Currently ignored (no ``scaling_factor`` division). Kept for
                API compatibility.

        Returns:
            Channels-last decoded images.

        Raises:
            RuntimeError: If the whitening statistics were not loaded.
        """
        channel_stats = self._channel_stats()
        B, H, W, C = latents.shape
        restored = []
        for c in range(C):
            mean, zca = channel_stats[c]
            x = latents[..., c].reshape(B, -1)
            # Solve zca @ y = x^T instead of forming inv(zca) explicitly. The
            # result is the same in exact arithmetic, but it avoids
            # materialising the inverse and is more numerically stable when
            # zca is ill-conditioned.
            x_unwhitened = jnp.linalg.solve(zca, x.T).T + mean
            restored.append(x_unwhitened.reshape(B, H, W))
        latents = jnp.stack(restored, axis=-1)

        # `.sample` reads the decoded images off the decoder output. ``apply``
        # is called without any PRNG key, so no random sampling happens here.
        images = self.module.apply(
            {"params": self.params}, latents, method=self.module.decode
        ).sample
        # The decoder output is channels-first; move channels last.
        images = rearrange(images, "b c h w -> b h w c")
        return images

    @cached_property
    def downscale_factor(self) -> int:
        """Spatial downsampling factor between images and latents."""
        # One factor of 2 for every entry in block_out_channels after the first.
        return 2 ** (len(self.module.block_out_channels) - 1)
 |