from functools import partial, cached_property
import pickle

import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL
from einops import rearrange
from flax import struct
from jaxtyping import Array, PyTree, Key, Float, Shaped, Int, UInt8, jaxtyped
from typeguard import typechecked

typecheck = partial(jaxtyped, typechecker=typechecked)

def load_stats(path="stats.pkl"):
    """Load per-channel whitening statistics: one (mean, zca) pair per latent channel."""
    with open(path, "rb") as f:
        return pickle.load(f)

try:
    stats = load_stats()  # per-channel (mean, zca) pairs used for latent whitening
except FileNotFoundError:
    stats = None  # whitening stats unavailable; encode/decode will fail until provided
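
# --- Hypothetical helper (not part of the original file): a sketch of how the
# per-channel (mean, zca) pairs in stats.pkl could be produced from a batch of
# un-whitened latents of shape (N, H, W, C). The name, eps value, and dict/list
# layout are assumptions; only the ZCA construction itself is standard.
def compute_stats(latents, eps=1e-5, path="stats.pkl"):
    N, H, W, C = latents.shape
    stats = []
    for c in range(C):
        x = latents[:, :, :, c].reshape(N, -1)        # (N, H*W) per-channel features
        mean = x.mean(axis=0)
        x_centered = x - mean
        cov = (x_centered.T @ x_centered) / (N - 1)   # (H*W, H*W) spatial covariance
        eigvals, eigvecs = jnp.linalg.eigh(cov)
        # ZCA whitening matrix: U diag(1/sqrt(lambda + eps)) U^T
        zca = eigvecs @ jnp.diag(1.0 / jnp.sqrt(eigvals + eps)) @ eigvecs.T
        stats.append((mean, zca))
    with open(path, "wb") as f:
        pickle.dump(stats, f)
    return stats
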
@struct.dataclass
class StableVAE:
    params: PyTree[Float[Array, "..."]]
    module: FlaxAutoencoderKL = struct.field(pytree_node=False)

    @classmethod
    def create(cls) -> "StableVAE":
        # Alternative checkpoint:
        # module, params = FlaxAutoencoderKL.from_pretrained(
        #     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
        # )
        module, params = FlaxAutoencoderKL.from_pretrained(
            "pcuenq/sd-vae-ft-mse-flax"
        )
        params = jax.device_get(params)
        return cls(
            params=params,
            module=module,
        )

    @partial(jax.jit, static_argnames="scale")
    def encode(
        self, key: Key[Array, ""], images: Float[Array, "b h w 3"], scale: bool = True
    ) -> Float[Array, "b lh lw 4"]:
        # The Flax VAE expects channels-first input.
        images = rearrange(images, "b h w c -> b c h w")
        latents = self.module.apply(
            {"params": self.params}, images, method=self.module.encode
        ).latent_dist.sample(key)

        # The sampled latents come back channels-last; whiten each channel
        # independently with its precomputed (mean, zca) statistics.
        B, H, W, C = latents.shape
        latents_whitened = jnp.zeros_like(latents)
        for c in range(C):
            x = latents[:, :, :, c].reshape(B, -1)  # (B, H*W) per-channel features
            mean, zca = stats[c]
            x_centered = x - mean
            x_whitened = (zca @ x_centered.T).T
            latents_whitened = latents_whitened.at[:, :, :, c].set(
                x_whitened.reshape(B, H, W)
            )

        # The config scaling factor is not applied; ZCA whitening replaces it,
        # so `scale` is currently unused.
        # if scale:
        #     latents_whitened *= self.module.config.scaling_factor
        return latents_whitened

    @partial(jax.jit, static_argnames="scale")
    def decode(
        self, latents: Float[Array, "b lh lw 4"], scale: bool = True
    ) -> Float[Array, "b h w 3"]:
        # The config scaling factor is not applied; un-whitening replaces it,
        # so `scale` is currently unused.
        # if scale:
        #     latents /= self.module.config.scaling_factor

        # Undo the per-channel ZCA whitening before handing the latents to the VAE.
        B, H, W, C = latents.shape
        latents_unwhitened = jnp.zeros_like(latents)
        for c in range(C):
            x = latents[:, :, :, c].reshape(B, -1)  # (B, H*W) per-channel features
            mean, zca = stats[c]
            zca_inv = jnp.linalg.inv(zca)
            x_unwhitened = (zca_inv @ x.T).T + mean
            latents_unwhitened = latents_unwhitened.at[:, :, :, c].set(
                x_unwhitened.reshape(B, H, W)
            )
        latents = latents_unwhitened

        # Note: `.sample` is just the name of the decoder output field;
        # decoding itself is deterministic, so no extra sampling happens here.
        images = self.module.apply(
            {"params": self.params}, latents, method=self.module.decode
        ).sample

        # The decoder returns channels-first images; convert back to channels-last.
        images = rearrange(images, "b c h w -> b h w c")
        return images

    @cached_property
    def downscale_factor(self) -> int:
        # Each down block after the first halves spatial resolution (8 for the SD VAE).
        return 2 ** (len(self.module.block_out_channels) - 1)
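

# --- Hypothetical usage sketch (not part of the original file): round-trip a
# dummy batch through the VAE. Assumes stats.pkl exists and images are
# channels-last in [-1, 1]; shapes are illustrative for the SD VAE.
if __name__ == "__main__":
    vae = StableVAE.create()
    key = jax.random.PRNGKey(0)
    images = jnp.zeros((4, 256, 256, 3))   # dummy channels-last images
    latents = vae.encode(key, images)      # whitened latents, e.g. (4, 32, 32, 4)
    recon = vae.decode(latents)            # reconstructed images, (4, 256, 256, 3)
    print(latents.shape, recon.shape, vae.downscale_factor)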