license: apache-2.0
Equivariant 16ch, f8 VAE
AuraEquiVAE is a novel autoencoder that fixes multiple problems of conventional VAEs. First, unlike a traditional VAE, whose learned log-variance is typically very small, this model admits large noise in the latent. Second, unlike a traditional VAE, the latent space is equivariant under the Z_2 x Z_2 group operation (horizontal / vertical flips).
To understand the equivariance, we apply a suitable group action to the latent both globally and locally. The latent is represented as Z = (z_1, \cdots, z_n). We apply a permutation action g_global to the tuple, where g_global is isomorphic to the Z_2 x Z_2 group, and we also apply an action g_local to each individual z_i, where g_local is likewise isomorphic to Z_2 x Z_2.

In our case specifically, g_global corresponds to spatial flips and g_local corresponds to sign flips on specific latent channels. Using two channels for the sign flip for each of the horizontal and vertical flips was chosen empirically.
The model was trained following Mastering VAE Training, where a detailed explanation of the training procedure can be found.
How to use
To use the weights, copy-paste the VAE implementation (`ae.py`) next to your script.
```python
from ae import VAE  # the copied VAE implementation
import torch
from PIL import Image
from torchvision import transforms
from safetensors.torch import load_file

vae = VAE(
    resolution=256,
    in_channels=3,
    ch=256,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,
).cuda().bfloat16()

state_dict = load_file("./vae_epoch_3_step_49501_bf16.pt")
vae.load_state_dict(state_dict)

imgpath = 'contents/lavender.jpg'
img_orig = Image.open(imgpath).convert("RGB")
offset = 128
W = 768
img_orig = img_orig.crop((offset, offset, W + offset, W + offset))
img = transforms.ToTensor()(img_orig).unsqueeze(0).cuda().bfloat16()
img = (img - 0.5) / 0.5  # normalize to [-1, 1]

with torch.no_grad():
    z = vae.encoder(img)
    z = z.clamp(-8.0, 8.0)  # this is the latent!

# flip horizontal
z = torch.flip(z, [-1])     # this corresponds to g_global
z[:, -4:-2] = -z[:, -4:-2]  # this corresponds to g_local

# flip vertical
z = torch.flip(z, [-2])
z[:, -2:] = -z[:, -2:]

with torch.no_grad():
    decz = vae.decoder(z)  # this is the image!

decimg = ((decz + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
decimg = (decimg * 255).astype('uint8')
decimg = Image.fromarray(decimg)  # PIL image
```
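The post-processing at the end (denormalize from [-1, 1], clamp, convert to a uint8 HWC array) can be factored into a small helper. This is a sketch mirroring the inline code above, exercised on a dummy tensor in place of real decoder output; the helper name `to_pil` is our own, not part of the repository:

```python
import torch
from PIL import Image

def to_pil(dec: torch.Tensor) -> Image.Image:
    # map a (1, 3, H, W) model output in [-1, 1] to a uint8 HWC PIL image
    arr = ((dec + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy()
    arr = (arr.transpose(1, 2, 0) * 255).astype('uint8')
    return Image.fromarray(arr)

out = to_pil(torch.rand(1, 3, 64, 64) * 2 - 1)  # dummy decoder output
out.save('recon.png')
```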
Citation
If you find this material useful, please cite:
@misc{ryu2024vqgantraining,
author = {Simo Ryu},
title = {Training VQGAN and VAE, with detailed explanation},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}},
}