import pickle
from functools import partial, cached_property

import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL
from einops import rearrange
from flax import struct
from jaxtyping import Array, PyTree, Key, Float, Shaped, Int, UInt8, jaxtyped
from typeguard import typechecked

typecheck = partial(jaxtyped, typechecker=typechecked)
def load_stats(path="stats.pkl"):
    """Load per-channel whitening statistics: stats[c] = (mean, zca)."""
    with open(path, "rb") as f:
        return pickle.load(f)


try:
    stats = load_stats()  # per-channel (mean, zca) pairs for ZCA (un)whitening
except FileNotFoundError:
    stats = None  # encode()/decode() below require these stats
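
# Sketch: one way the per-channel (mean, zca) pairs stored in stats.pkl could be
# computed from a batch of raw (unwhitened) VAE latents. `compute_whitening_stats`
# is a hypothetical helper that is not called anywhere in this module; it assumes
# latents of shape (N, lh, lw, 4), matching what encode() samples before whitening.
def compute_whitening_stats(latents, eps=1e-5):
    N, H, W, C = latents.shape
    per_channel = []
    for c in range(C):
        x = latents[:, :, :, c].reshape(N, -1)       # (N, lh*lw)
        mean = x.mean(axis=0)
        x_centered = x - mean
        cov = (x_centered.T @ x_centered) / (N - 1)  # (lh*lw, lh*lw)
        s, u = jnp.linalg.eigh(cov)
        zca = u @ jnp.diag(1.0 / jnp.sqrt(s + eps)) @ u.T  # ZCA whitening matrix
        per_channel.append((mean, zca))
    return per_channel  # pickle this list to produce stats.pkl
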
@struct.dataclass
class StableVAE:
    """Stable Diffusion VAE wrapper with per-channel ZCA whitening of its latents."""

    params: PyTree[Float[Array, "..."]]
    module: FlaxAutoencoderKL = struct.field(pytree_node=False)

    @classmethod
    def create(cls) -> "StableVAE":
        # Alternative checkpoint:
        # module, params = FlaxAutoencoderKL.from_pretrained(
        #     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
        # )
        module, params = FlaxAutoencoderKL.from_pretrained(
            "pcuenq/sd-vae-ft-mse-flax"
        )
        params = jax.device_get(params)
        return cls(params=params, module=module)
    @partial(jax.jit, static_argnames="scale")
    def encode(
        self, key: Key[Array, ""], images: Float[Array, "b h w 3"], scale: bool = True
    ) -> Float[Array, "b lh lw 4"]:
        # FlaxAutoencoderKL expects channels-first images.
        images = rearrange(images, "b h w c -> b c h w")
        latents = self.module.apply(
            {"params": self.params}, images, method=self.module.encode
        ).latent_dist.sample(key)
        # The sampled latents come back channels-last: (b, lh, lw, 4).
        B, H, W, C = latents.shape
        # Per-channel ZCA whitening over the flattened (lh * lw)-dim vectors;
        # `scale` (config.scaling_factor) is unused because whitening replaces it.
        latents_whitened = jnp.zeros_like(latents)
        for c in range(C):
            x = latents[:, :, :, c].reshape(B, -1)
            mean, zca = stats[c]
            x_centered = x - mean
            x_whitened = (zca @ x_centered.T).T
            latents_whitened = latents_whitened.at[:, :, :, c].set(
                x_whitened.reshape(B, H, W)
            )
        return latents_whitened
    @partial(jax.jit, static_argnames="scale")
    def decode(
        self, latents: Float[Array, "b lh lw 4"], scale: bool = True
    ) -> Float[Array, "b h w 3"]:
        # Undo the per-channel ZCA whitening applied in encode(); `scale`
        # (config.scaling_factor) is unused here for the same reason.
        B, H, W, C = latents.shape
        latents_unwhitened = jnp.zeros_like(latents)
        for c in range(C):
            x = latents[:, :, :, c].reshape(B, -1)
            mean, zca = stats[c]
            zca_inv = jnp.linalg.inv(zca)
            x_unwhitened = (zca_inv @ x.T).T + mean
            latents_unwhitened = latents_unwhitened.at[:, :, :, c].set(
                x_unwhitened.reshape(B, H, W)
            )
        latents = latents_unwhitened
        # Decoding is deterministic; sampling only happens in encode().
        images = self.module.apply(
            {"params": self.params}, latents, method=self.module.decode
        ).sample
        # The decoder returns channels-first images; convert back to channels-last.
        images = rearrange(images, "b c h w -> b h w c")
        return images
    @cached_property
    def downscale_factor(self) -> int:
        # Spatial downsampling from image to latent (8x for the SD VAE).
        return 2 ** (len(self.module.block_out_channels) - 1)
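
# Minimal usage sketch, assuming stats.pkl is present and that inputs are float
# images scaled to roughly [-1, 1] (the range the Stable Diffusion VAE expects).
if __name__ == "__main__":
    vae = StableVAE.create()
    key = jax.random.PRNGKey(0)
    images = jax.random.uniform(key, (1, 256, 256, 3), minval=-1.0, maxval=1.0)
    latents = vae.encode(key, images)  # (1, 32, 32, 4): 256 / downscale_factor
    recon = vae.decode(latents)        # (1, 256, 256, 3)
    print(latents.shape, recon.shape, vae.downscale_factor)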