Upload liquid_flow/vae_wrapper.py
Browse files - liquid_flow/vae_wrapper.py +28 -58
liquid_flow/vae_wrapper.py
CHANGED
|
@@ -1,53 +1,33 @@
|
|
| 1 |
"""
|
| 2 |
-
VAE Wrappers —
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
Model: madebyollin/taesd (335K downloads on HF)
|
| 13 |
"""
|
| 14 |
|
| 15 |
import torch
|
| 16 |
-
import torch.nn as nn
|
| 17 |
-
import torch.nn.functional as F
|
| 18 |
-
from typing import Optional
|
| 19 |
|
| 20 |
|
| 21 |
class TAESDWrapper:
|
| 22 |
"""
|
| 23 |
Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).
|
| 24 |
|
| 25 |
-
TAESD
|
| 26 |
-
|
| 27 |
-
- Latent dim: 4 channels @ 8x compression
|
| 28 |
-
- Extremely fast encode/decode
|
| 29 |
-
- Works on CPU — no GPU needed
|
| 30 |
-
- Perfect for Colab/Kaggle free tier
|
| 31 |
|
| 32 |
-
Model
|
| 33 |
"""
|
| 34 |
|
| 35 |
-
def __init__(self, device='cpu'):
|
| 36 |
-
self.device = device
|
| 37 |
-
self.model = None
|
| 38 |
-
|
| 39 |
-
@staticmethod
|
| 40 |
-
def is_available():
|
| 41 |
-
"""Check if TAESD can be loaded."""
|
| 42 |
-
try:
|
| 43 |
-
from diffusers import AutoencoderTiny
|
| 44 |
-
return True
|
| 45 |
-
except ImportError:
|
| 46 |
-
return False
|
| 47 |
-
|
| 48 |
@staticmethod
|
| 49 |
def load(device='cpu'):
|
| 50 |
-
"""Load TAESD model."""
|
| 51 |
from diffusers import AutoencoderTiny
|
| 52 |
model = AutoencoderTiny.from_pretrained(
|
| 53 |
"madebyollin/taesd",
|
|
@@ -57,25 +37,19 @@ class TAESDWrapper:
|
|
| 57 |
model.eval()
|
| 58 |
return model
|
| 59 |
|
| 60 |
-
@staticmethod
|
| 61 |
-
def get_latent_shape(image_size):
|
| 62 |
-
"""Get latent spatial size given image size (8x compression)."""
|
| 63 |
-
return image_size // 8
|
| 64 |
-
|
| 65 |
@staticmethod
|
| 66 |
def encode(vae, x):
|
| 67 |
"""
|
| 68 |
Encode image to latent.
|
| 69 |
Args:
|
| 70 |
-
vae:
|
| 71 |
x: [B, 3, H, W] images in [-1, 1]
|
| 72 |
Returns:
|
| 73 |
-
z: [B, 4, H/8, W/8]
|
| 74 |
"""
|
| 75 |
with torch.no_grad():
|
| 76 |
-
|
| 77 |
-
z =
|
| 78 |
-
z = z * vae.config.scaling_factor
|
| 79 |
return z
|
| 80 |
|
| 81 |
@staticmethod
|
|
@@ -83,34 +57,30 @@ class TAESDWrapper:
|
|
| 83 |
"""
|
| 84 |
Decode latent to image.
|
| 85 |
Args:
|
| 86 |
-
vae:
|
| 87 |
-
z: [B, 4, H/8, W/8]
|
| 88 |
Returns:
|
| 89 |
x: [B, 3, H, W] images in [-1, 1]
|
| 90 |
"""
|
| 91 |
with torch.no_grad():
|
| 92 |
-
z = z / vae.config.scaling_factor
|
| 93 |
x = vae.decode(z).sample
|
| 94 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
class SDVAEWrapper:
|
| 98 |
"""
|
| 99 |
Wrapper for Stability AI VAE (sd-vae-ft-mse).
|
| 100 |
|
| 101 |
-
|
| 102 |
-
- ~84M parameters
|
| 103 |
-
- Latent dim: 4 channels @ 8x compression
|
| 104 |
-
- Higher quality reconstruction than TAESD
|
| 105 |
-
- Requires GPU for reasonable speed
|
| 106 |
|
| 107 |
-
Model
|
| 108 |
"""
|
| 109 |
|
| 110 |
-
def __init__(self, device='cpu'):
|
| 111 |
-
self.device = device
|
| 112 |
-
self.model = None
|
| 113 |
-
|
| 114 |
@staticmethod
|
| 115 |
def load(device='cpu'):
|
| 116 |
"""Load SD VAE model."""
|
|
@@ -125,7 +95,7 @@ class SDVAEWrapper:
|
|
| 125 |
|
| 126 |
@staticmethod
|
| 127 |
def encode(vae, x):
|
| 128 |
-
"""Encode image to latent."""
|
| 129 |
with torch.no_grad():
|
| 130 |
posterior = vae.encode(x).latent_dist
|
| 131 |
z = posterior.sample()
|
|
@@ -134,7 +104,7 @@ class SDVAEWrapper:
|
|
| 134 |
|
| 135 |
@staticmethod
|
| 136 |
def decode(vae, z):
|
| 137 |
-
"""Decode latent to image."""
|
| 138 |
with torch.no_grad():
|
| 139 |
z = z / vae.config.scaling_factor
|
| 140 |
x = vae.decode(z).sample
|
|
|
|
| 1 |
"""
|
| 2 |
+
VAE Wrappers — corrected for actual TAESD and SD-VAE APIs.
|
| 3 |
|
| 4 |
+
TAESD (AutoencoderTiny):
|
| 5 |
+
- encode(x) returns AutoencoderTinyOutput with .latents (no sampling)
|
| 6 |
+
- scaling_factor = 1.0 (no scaling needed)
|
| 7 |
+
- decode(z) returns DecoderOutput with .sample
|
| 8 |
|
| 9 |
+
SD-VAE (AutoencoderKL):
|
| 10 |
+
- encode(x) returns AutoencoderKLOutput with .latent_dist
|
| 11 |
+
- scaling_factor = 0.18215
|
| 12 |
+
- decode(z) returns DecoderOutput with .sample
|
|
|
|
| 13 |
"""
|
| 14 |
|
| 15 |
import torch
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class TAESDWrapper:
|
| 19 |
"""
|
| 20 |
Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).
|
| 21 |
|
| 22 |
+
Key: TAESD uses .latents directly (deterministic encoder, no sampling).
|
| 23 |
+
scaling_factor = 1.0, so no scaling needed.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
Model: madebyollin/taesd (~2.5M params, 9.8MB)
|
| 26 |
"""
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
@staticmethod
|
| 29 |
def load(device='cpu'):
|
| 30 |
+
"""Load TAESD model from HuggingFace."""
|
| 31 |
from diffusers import AutoencoderTiny
|
| 32 |
model = AutoencoderTiny.from_pretrained(
|
| 33 |
"madebyollin/taesd",
|
|
|
|
| 37 |
model.eval()
|
| 38 |
return model
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
@staticmethod
|
| 41 |
def encode(vae, x):
|
| 42 |
"""
|
| 43 |
Encode image to latent.
|
| 44 |
Args:
|
| 45 |
+
vae: AutoencoderTiny model
|
| 46 |
x: [B, 3, H, W] images in [-1, 1]
|
| 47 |
Returns:
|
| 48 |
+
z: [B, 4, H/8, W/8] latents
|
| 49 |
"""
|
| 50 |
with torch.no_grad():
|
| 51 |
+
# TAESD returns .latents directly (no latent_dist)
|
| 52 |
+
z = vae.encode(x).latents
|
|
|
|
| 53 |
return z
|
| 54 |
|
| 55 |
@staticmethod
|
|
|
|
| 57 |
"""
|
| 58 |
Decode latent to image.
|
| 59 |
Args:
|
| 60 |
+
vae: AutoencoderTiny model
|
| 61 |
+
z: [B, 4, H/8, W/8] latents
|
| 62 |
Returns:
|
| 63 |
x: [B, 3, H, W] images in [-1, 1]
|
| 64 |
"""
|
| 65 |
with torch.no_grad():
|
|
|
|
| 66 |
x = vae.decode(z).sample
|
| 67 |
return x
|
| 68 |
+
|
| 69 |
+
@staticmethod
|
| 70 |
+
def get_latent_shape(image_size):
|
| 71 |
+
"""Get latent spatial size (8x compression)."""
|
| 72 |
+
return image_size // 8
|
| 73 |
|
| 74 |
|
| 75 |
class SDVAEWrapper:
|
| 76 |
"""
|
| 77 |
Wrapper for Stability AI VAE (sd-vae-ft-mse).
|
| 78 |
|
| 79 |
+
Key: Uses .latent_dist.sample() and scaling_factor=0.18215.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
Model: stabilityai/sd-vae-ft-mse (~84M params)
|
| 82 |
"""
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
@staticmethod
|
| 85 |
def load(device='cpu'):
|
| 86 |
"""Load SD VAE model."""
|
|
|
|
| 95 |
|
| 96 |
@staticmethod
|
| 97 |
def encode(vae, x):
|
| 98 |
+
"""Encode image to latent (with scaling)."""
|
| 99 |
with torch.no_grad():
|
| 100 |
posterior = vae.encode(x).latent_dist
|
| 101 |
z = posterior.sample()
|
|
|
|
| 104 |
|
| 105 |
@staticmethod
|
| 106 |
def decode(vae, z):
|
| 107 |
+
"""Decode latent to image (with unscaling)."""
|
| 108 |
with torch.no_grad():
|
| 109 |
z = z / vae.config.scaling_factor
|
| 110 |
x = vae.decode(z).sample
|