Equivariant 16ch, f8 VAE
AuraEquiVAE is a novel autoencoder that addresses several problems of conventional VAEs. First, unlike traditional VAEs, whose learned log-variance is very small, this model admits large noise in the latent space.
Additionally, unlike traditional VAEs, the latent space is equivariant under Z_2 x Z_2 group operations (horizontal / vertical flips).
To understand the equivariance, we apply suitable group actions to the latent space, both globally and locally. The latent is represented as Z = (z_1, ..., z_n), and we perform a global permutation group action g_global on the tuple such that g_global is isomorphic to the Z_2 x Z_2 group.
We also apply a local action g_local to the individual z_i elements such that g_local is also isomorphic to the Z_2 x Z_2 group.
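One way to write the property down, as a sketch of how we read the description rather than a statement taken from the training code: the symbols E (encoder), D (decoder), \pi_g (the permutation of latent positions) and s_g (the per-channel signs) are notation introduced here for illustration, and z_i is assumed to be the 16-channel latent vector at spatial position i.

```latex
% Notation introduced here for illustration only.
\[
G = \mathbb{Z}_2 \times \mathbb{Z}_2 = \{e,\, h,\, v,\, hv\}, \qquad
Z = (z_1, \dots, z_n), \quad z_i \in \mathbb{R}^{16}
\]
% g_global permutes positions via \pi_g; g_local multiplies channels by signs s_g.
\[
(g \cdot Z)_i = s_g \odot z_{\pi_g(i)}, \qquad
s_g \in \{+1, -1\}^{16}, \qquad \pi_g \circ \pi_g = \mathrm{id}
\]
% Equivariance: acting on the image and acting on the latent agree.
\[
E(g \cdot x) = g \cdot E(x), \qquad D(g \cdot Z) = g \cdot D(Z)
\]
```

Here g \cdot x denotes the corresponding flip of the image x. Because every element of Z_2 x Z_2 is its own inverse, applying the same action twice returns the original latent.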
In our specific case, g_global corresponds to spatial flips, while g_local corresponds to sign flips on specific latent dimensions. The choice of sign-flipping 2 channels for each of the horizontal and vertical directions was made empirically.
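The combined action can be sketched without any dependencies. The example below is a minimal illustration, not the model's code: it stores a single latent as nested lists indexed latent[c][y][x], and the function names flip_horizontal and flip_vertical are hypothetical. The channel choices (third- and fourth-from-last for horizontal, last two for vertical) follow the tensor slices used in the usage snippet under "How to use".

```python
# Dependency-free sketch of the Z_2 x Z_2 action on one latent of shape (C, H, W).
# The latent is a nested list: latent[c][y][x]. Function names are illustrative.

def flip_horizontal(latent):
    """g_global reverses every row; g_local negates channels C-4 and C-3."""
    n = len(latent)
    negated = {n - 4, n - 3}
    return [
        [[(-v if c in negated else v) for v in reversed(row)] for row in channel]
        for c, channel in enumerate(latent)
    ]


def flip_vertical(latent):
    """g_global reverses the row order; g_local negates the last two channels."""
    n = len(latent)
    negated = {n - 2, n - 1}
    return [
        [[(-v if c in negated else v) for v in row] for row in reversed(channel)]
        for c, channel in enumerate(latent)
    ]


# Toy latent: 16 channels on a 2 x 3 grid, with distinct nonzero-friendly values.
latent = [
    [[float(c * 100 + y * 10 + x + 1) for x in range(3)] for y in range(2)]
    for c in range(16)
]

# Group sanity checks: each flip is an involution, and the two flips commute.
assert flip_horizontal(flip_horizontal(latent)) == latent
assert flip_vertical(flip_vertical(latent)) == latent
assert flip_horizontal(flip_vertical(latent)) == flip_vertical(flip_horizontal(latent))
```

The three assertions are exactly the relations that make the action a representation of Z_2 x Z_2: two commuting generators, each of order two.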
The model has been trained using the approach described in Mastering VAE Training, where detailed explanations for the training process can be found.
How to use
To use the weights, copy the VAE implementation into your project so that it can be imported as the module ae.
from ae import VAE
import torch
from PIL import Image
from torchvision import transforms  # needed for transforms.ToTensor() below
vae = VAE(
    resolution=256,
    in_channels=3,
    ch=256,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,
).cuda().bfloat16()
from safetensors.torch import load_file
state_dict = load_file("./vae_epoch_3_step_49501_bf16.pt")
vae.load_state_dict(state_dict)
imgpath = 'contents/lavender.jpg'
img_orig = Image.open(imgpath).convert("RGB")
offset = 128
W = 768
img_orig = img_orig.crop((offset, offset, W + offset, W + offset))
img = transforms.ToTensor()(img_orig).unsqueeze(0).cuda()
img = ((img - 0.5) / 0.5).bfloat16()  # scale [0, 1] to [-1, 1] and match the model's dtype
with torch.no_grad():
    z = vae.encoder(img)
    z = z.clamp(-8.0, 8.0)  # this is the latent
# flip horizontal
z = torch.flip(z, [-1]) # this corresponds to g_global
z[:, -4:-2] = -z[:, -4:-2] # this corresponds to g_local
# flip vertical
z = torch.flip(z, [-2])
z[:, -2:] = -z[:, -2:]
with torch.no_grad():
    decz = vae.decoder(z)  # this is the decoded image tensor
decimg = ((decz + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
decimg = (decimg * 255).astype('uint8')
decimg = Image.fromarray(decimg) # PIL image.
Citation
If you find this model useful, please cite:
@misc{ryu2024vqgantraining,
author = {Simo Ryu},
title = {Training VQGAN and VAE, with detailed explanation},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/cloneofsimo/vqgan-training}},
}