---
license: apache-2.0
---

Equivariant 16ch, f8 VAE

AuraEquiVAE is a novel autoencoder that fixes multiple problems of conventional VAEs. First, unlike traditional VAEs, whose learned log-variance is typically very small, this model admits large noise in the latent. Second, unlike traditional VAEs, the latent space is equivariant under a Z_2 x Z_2 group action (horizontal / vertical flip).
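For reference, the sketch below shows the generic VAE reparameterization this claim refers to. It is illustrative only, not AuraEquiVAE's training code, and the noise scales are made-up numbers: when the learned log-variance collapses to strongly negative values, the injected noise is effectively zero, while a model that admits large noise keeps it at a non-trivial scale.

import torch

# Generic reparameterization z = mu + sigma * eps (illustration only).
mu = torch.zeros(1, 16, 32, 32)                 # dummy latent mean
logvar_collapsed = torch.full_like(mu, -12.0)   # strongly negative log-variance => near-zero noise
z_small = mu + torch.exp(0.5 * logvar_collapsed) * torch.randn_like(mu)

logvar_large = torch.zeros_like(mu)             # unit-scale noise, an arbitrary illustrative value
z_large = mu + torch.exp(0.5 * logvar_large) * torch.randn_like(mu)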

To understand the equivariance, we equip the latent with a suitable group action both globally and locally. That is, with the latent represented as Z = (z_1, \cdots, z_n), a permutation action g_global acts on the tuple as a whole and generates a group isomorphic to Z_2 x Z_2, while a second action g_local acts on the individual z_i themselves and is also isomorphic to Z_2 x Z_2.

In our case specifically, g_global corresponds to spatial flips of the latent, and g_local corresponds to sign flips on specific latent channels. Using 2 channels for the sign flip for each of the horizontal and vertical flips was chosen empirically.
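As a concrete sketch of this Z_2 x Z_2 structure (illustrative code, not part of the model implementation), the actions below use the same channel convention as the usage example further down: the last 4 latent channels carry the sign flips, 2 per axis. Each action squares to the identity and the two actions commute.

import torch

def g_horizontal(z):
    z = torch.flip(z, [-1])     # g_global: flip latent along width
    z[:, -4:-2] = -z[:, -4:-2]  # g_local: sign flip on channels -4:-2
    return z

def g_vertical(z):
    z = torch.flip(z, [-2])     # g_global: flip latent along height
    z[:, -2:] = -z[:, -2:]      # g_local: sign flip on the last 2 channels
    return z

z = torch.randn(1, 16, 32, 32)  # dummy 16-channel latent
assert torch.allclose(g_horizontal(g_horizontal(z)), z)  # involution
assert torch.allclose(g_vertical(g_vertical(z)), z)      # involution
assert torch.allclose(g_horizontal(g_vertical(z)), g_vertical(g_horizontal(z)))  # commute => Z_2 x Z_2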

The model was trained following Mastering VAE Training, where a detailed explanation of the training procedure can be found.

How to use

To use the weights, copy-paste the VAE implementation (the ae module imported below) and load the checkpoint as follows.

from ae import VAE
import torch
from PIL import Image
from torchvision import transforms

vae = VAE(
    resolution=256,
    in_channels=3,
    ch=256,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,  # assumed: 16 latent channels, per the "16ch" f8 model name
).cuda().bfloat16()

from safetensors.torch import load_file
state_dict = load_file("./vae_epoch_3_step_49501_bf16.pt")
vae.load_state_dict(state_dict)

imgpath = 'contents/lavender.jpg'

img_orig = Image.open(imgpath).convert("RGB")
offset = 128
W = 768
img_orig = img_orig.crop((offset, offset, W + offset, W + offset))
img = transforms.ToTensor()(img_orig).unsqueeze(0).cuda().bfloat16()  # match the model's dtype
img = (img - 0.5) / 0.5  # normalize to [-1, 1]

with torch.no_grad():
    z = vae.encoder(img)
    z = z.clamp(-8.0, 8.0) # this is latent!!

# flip horizontal
z = torch.flip(z, [-1])     # g_global: flip the latent along the width axis
z[:, -4:-2] = -z[:, -4:-2]  # g_local: sign flip on channels -4:-2

# flip vertical
z = torch.flip(z, [-2])     # g_global: flip the latent along the height axis
z[:, -2:] = -z[:, -2:]      # g_local: sign flip on the last 2 channels


with torch.no_grad():
    decz = vae.decoder(z) # this is image!

decimg = ((decz + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
decimg = (decimg * 255).astype('uint8')
decimg = Image.fromarray(decimg) # PIL image.
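A quick way to sanity-check the equivariance with the loaded weights is to compare the latent of a flipped image against the group action applied to the latent of the original image. A rough check, assuming the setup above (the agreement is approximate, not exact):

with torch.no_grad():
    z_orig = vae.encoder(img).clamp(-8.0, 8.0)                    # latent of the original image
    z_flip = vae.encoder(torch.flip(img, [-1])).clamp(-8.0, 8.0)  # latent of the h-flipped image

z_pred = torch.flip(z_orig, [-1])     # g_global for a horizontal flip
z_pred[:, -4:-2] = -z_pred[:, -4:-2]  # g_local for a horizontal flip

print((z_pred - z_flip).abs().mean())  # small if the latent is (approximately) equivariant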

Citation

If you find this material useful, please cite:

@misc{ryu2024vqgantraining,
  author = {Simo Ryu},
  title = {Training VQGAN and VAE, with detailed explanation},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}},
}