# LiquidFlow-Gen / liquid_flow/vae_wrapper.py
# (uploaded by krystv — HF commit b9e8cb3, verified)
"""
VAE Wrappers — compatible VAE interfaces for LiquidFlow.
Supports two VAE backends:
1. TAESD (Tiny AutoEncoder for SD): < 1M params, extremely fast, perfect for mobile
2. SD-VAE (Stability AI VAE): Higher quality, 84M params, standard for SD pipelines
TAESD is the DEFAULT for LiquidFlow — it's designed to be lightweight and
fast enough for Colab/Kaggle free tier.
Paper reference: "Tiny AutoEncoder for Stable Diffusion" (madebyollin/taesd)
Model: madebyollin/taesd (335K downloads on HF)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class TAESDWrapper:
    """
    Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).

    TAESD properties:
    - ~1M parameters (vs 84M for SD VAE)
    - Latent dim: 4 channels @ 8x compression
    - Extremely fast encode/decode
    - Works on CPU — no GPU needed
    - Perfect for Colab/Kaggle free tier

    Model on HF: madebyollin/taesd

    NOTE: unlike AutoencoderKL, AutoencoderTiny is deterministic — its
    encode() output carries `.latents` directly rather than a
    `.latent_dist` posterior to sample from.
    """

    def __init__(self, device: str = 'cpu'):
        # Target device string (e.g. 'cpu', 'cuda'); model is loaded
        # separately via the static load() helper.
        self.device = device
        self.model = None

    @staticmethod
    def is_available() -> bool:
        """Return True if the diffusers AutoencoderTiny class is importable."""
        try:
            from diffusers import AutoencoderTiny  # noqa: F401
            return True
        except ImportError:
            return False

    @staticmethod
    def load(device: str = 'cpu'):
        """Load the pretrained TAESD model onto *device* in eval mode."""
        from diffusers import AutoencoderTiny
        model = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model

    @staticmethod
    def get_latent_shape(image_size: int) -> int:
        """Get latent spatial size given image size (8x compression)."""
        return image_size // 8

    @staticmethod
    def encode(vae, x: torch.Tensor) -> torch.Tensor:
        """
        Encode image to latent.

        Args:
            vae: TAESD model (diffusers AutoencoderTiny)
            x: [B, 3, H, W] images in [-1, 1]
        Returns:
            z: [B, 4, H/8, W/8] scaled latents
        """
        with torch.no_grad():
            # BUGFIX: AutoencoderTiny's encode() output exposes `.latents`,
            # not a `.latent_dist` posterior (that is the AutoencoderKL
            # interface). The previous code raised AttributeError here.
            z = vae.encode(x).latents
            # scaling_factor is 1.0 for TAESD; kept for interface symmetry
            # with SDVAEWrapper.
            z = z * vae.config.scaling_factor
        return z

    @staticmethod
    def decode(vae, z: torch.Tensor) -> torch.Tensor:
        """
        Decode latent to image.

        Args:
            vae: TAESD model (diffusers AutoencoderTiny)
            z: [B, 4, H/8, W/8] scaled latents
        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x
class SDVAEWrapper:
    """
    Wrapper around the Stability AI fine-tuned VAE (sd-vae-ft-mse).

    Properties:
    - ~84M parameters
    - Latent dim: 4 channels @ 8x compression
    - Higher quality reconstruction than TAESD
    - Requires GPU for reasonable speed

    Model on HF: stabilityai/sd-vae-ft-mse
    """

    def __init__(self, device='cpu'):
        # Remember the target device; the heavy model object is only
        # created when callers invoke the static load() helper.
        self.device = device
        self.model = None

    @staticmethod
    def load(device='cpu'):
        """Fetch the pretrained VAE, move it to *device*, and switch to eval mode."""
        from diffusers import AutoencoderKL
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        ).to(device)
        vae.eval()
        return vae

    @staticmethod
    def encode(vae, x):
        """Encode [B, 3, H, W] images in [-1, 1] into scaled [B, 4, H/8, W/8] latents."""
        with torch.no_grad():
            # AutoencoderKL is variational: draw a sample from the
            # encoder posterior, then apply the config scaling factor.
            latents = vae.encode(x).latent_dist.sample()
            return latents * vae.config.scaling_factor

    @staticmethod
    def decode(vae, z):
        """Decode scaled [B, 4, H/8, W/8] latents back to [B, 3, H, W] images in [-1, 1]."""
        with torch.no_grad():
            # Undo the scaling applied in encode() before decoding.
            return vae.decode(z / vae.config.scaling_factor).sample