krystv commited on
Commit
3798d56
·
verified ·
1 Parent(s): 68363aa

Upload liquid_flow/vae_wrapper.py

Browse files
Files changed (1) hide show
  1. liquid_flow/vae_wrapper.py +28 -58
liquid_flow/vae_wrapper.py CHANGED
@@ -1,53 +1,33 @@
1
  """
2
- VAE Wrappers — compatible VAE interfaces for LiquidFlow.
3
 
4
- Supports two VAE backends:
5
- 1. TAESD (Tiny AutoEncoder for SD): < 1M params, extremely fast, perfect for mobile
6
- 2. SD-VAE (Stability AI VAE): Higher quality, 84M params, standard for SD pipelines
 
7
 
8
- TAESD is the DEFAULT for LiquidFlow — it's designed to be lightweight and
9
- fast enough for Colab/Kaggle free tier.
10
-
11
- Paper reference: "Tiny AutoEncoder for Stable Diffusion" (madebyollin/taesd)
12
- Model: madebyollin/taesd (335K downloads on HF)
13
  """
14
 
15
  import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- from typing import Optional
19
 
20
 
21
  class TAESDWrapper:
22
  """
23
  Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).
24
 
25
- TAESD properties:
26
- - ~1M parameters (vs 84M for SD VAE)
27
- - Latent dim: 4 channels @ 8x compression
28
- - Extremely fast encode/decode
29
- - Works on CPU — no GPU needed
30
- - Perfect for Colab/Kaggle free tier
31
 
32
- Model on HF: madebyollin/taesd
33
  """
34
 
35
- def __init__(self, device='cpu'):
36
- self.device = device
37
- self.model = None
38
-
39
- @staticmethod
40
- def is_available():
41
- """Check if TAESD can be loaded."""
42
- try:
43
- from diffusers import AutoencoderTiny
44
- return True
45
- except ImportError:
46
- return False
47
-
48
  @staticmethod
49
  def load(device='cpu'):
50
- """Load TAESD model."""
51
  from diffusers import AutoencoderTiny
52
  model = AutoencoderTiny.from_pretrained(
53
  "madebyollin/taesd",
@@ -57,25 +37,19 @@ class TAESDWrapper:
57
  model.eval()
58
  return model
59
 
60
- @staticmethod
61
- def get_latent_shape(image_size):
62
- """Get latent spatial size given image size (8x compression)."""
63
- return image_size // 8
64
-
65
  @staticmethod
66
  def encode(vae, x):
67
  """
68
  Encode image to latent.
69
  Args:
70
- vae: TAESD model
71
  x: [B, 3, H, W] images in [-1, 1]
72
  Returns:
73
- z: [B, 4, H/8, W/8]
74
  """
75
  with torch.no_grad():
76
- posterior = vae.encode(x).latent_dist
77
- z = posterior.sample()
78
- z = z * vae.config.scaling_factor
79
  return z
80
 
81
  @staticmethod
@@ -83,34 +57,30 @@ class TAESDWrapper:
83
  """
84
  Decode latent to image.
85
  Args:
86
- vae: TAESD model
87
- z: [B, 4, H/8, W/8]
88
  Returns:
89
  x: [B, 3, H, W] images in [-1, 1]
90
  """
91
  with torch.no_grad():
92
- z = z / vae.config.scaling_factor
93
  x = vae.decode(z).sample
94
  return x
 
 
 
 
 
95
 
96
 
97
  class SDVAEWrapper:
98
  """
99
  Wrapper for Stability AI VAE (sd-vae-ft-mse).
100
 
101
- Properties:
102
- - ~84M parameters
103
- - Latent dim: 4 channels @ 8x compression
104
- - Higher quality reconstruction than TAESD
105
- - Requires GPU for reasonable speed
106
 
107
- Model on HF: stabilityai/sd-vae-ft-mse
108
  """
109
 
110
- def __init__(self, device='cpu'):
111
- self.device = device
112
- self.model = None
113
-
114
  @staticmethod
115
  def load(device='cpu'):
116
  """Load SD VAE model."""
@@ -125,7 +95,7 @@ class SDVAEWrapper:
125
 
126
  @staticmethod
127
  def encode(vae, x):
128
- """Encode image to latent."""
129
  with torch.no_grad():
130
  posterior = vae.encode(x).latent_dist
131
  z = posterior.sample()
@@ -134,7 +104,7 @@ class SDVAEWrapper:
134
 
135
  @staticmethod
136
  def decode(vae, z):
137
- """Decode latent to image."""
138
  with torch.no_grad():
139
  z = z / vae.config.scaling_factor
140
  x = vae.decode(z).sample
 
1
  """
2
+ VAE Wrappers — corrected for actual TAESD and SD-VAE APIs.
3
 
4
+ TAESD (AutoencoderTiny):
5
+ - encode(x) returns AutoencoderTinyOutput with .latents (no sampling)
6
+ - scaling_factor = 1.0 (no scaling needed)
7
+ - decode(z) returns DecoderOutput with .sample
8
 
9
+ SD-VAE (AutoencoderKL):
10
+ - encode(x) returns AutoEncoderKLOutput with .latent_dist
11
+ - scaling_factor = 0.18215
12
+ - decode(z) returns DecoderOutput with .sample
 
13
  """
14
 
15
  import torch
 
 
 
16
 
17
 
18
  class TAESDWrapper:
19
  """
20
  Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).
21
 
22
+ Key: TAESD uses .latents directly (deterministic encoder, no sampling).
23
+ scaling_factor = 1.0, so no scaling needed.
 
 
 
 
24
 
25
+ Model: madebyollin/taesd (~2.5M params, 9.8MB)
26
  """
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  @staticmethod
29
  def load(device='cpu'):
30
+ """Load TAESD model from HuggingFace."""
31
  from diffusers import AutoencoderTiny
32
  model = AutoencoderTiny.from_pretrained(
33
  "madebyollin/taesd",
 
37
  model.eval()
38
  return model
39
 
 
 
 
 
 
40
  @staticmethod
41
  def encode(vae, x):
42
  """
43
  Encode image to latent.
44
  Args:
45
+ vae: AutoencoderTiny model
46
  x: [B, 3, H, W] images in [-1, 1]
47
  Returns:
48
+ z: [B, 4, H/8, W/8] latents
49
  """
50
  with torch.no_grad():
51
+ # TAESD returns .latents directly (no latent_dist)
52
+ z = vae.encode(x).latents
 
53
  return z
54
 
55
  @staticmethod
 
57
  """
58
  Decode latent to image.
59
  Args:
60
+ vae: AutoencoderTiny model
61
+ z: [B, 4, H/8, W/8] latents
62
  Returns:
63
  x: [B, 3, H, W] images in [-1, 1]
64
  """
65
  with torch.no_grad():
 
66
  x = vae.decode(z).sample
67
  return x
68
+
69
+ @staticmethod
70
+ def get_latent_shape(image_size):
71
+ """Get latent spatial size (8x compression)."""
72
+ return image_size // 8
73
 
74
 
75
  class SDVAEWrapper:
76
  """
77
  Wrapper for Stability AI VAE (sd-vae-ft-mse).
78
 
79
+ Key: Uses .latent_dist.sample() and scaling_factor=0.18215.
 
 
 
 
80
 
81
+ Model: stabilityai/sd-vae-ft-mse (~84M params)
82
  """
83
 
 
 
 
 
84
  @staticmethod
85
  def load(device='cpu'):
86
  """Load SD VAE model."""
 
95
 
96
  @staticmethod
97
  def encode(vae, x):
98
+ """Encode image to latent (with scaling)."""
99
  with torch.no_grad():
100
  posterior = vae.encode(x).latent_dist
101
  z = posterior.sample()
 
104
 
105
  @staticmethod
106
  def decode(vae, z):
107
+ """Decode latent to image (with unscaling)."""
108
  with torch.no_grad():
109
  z = z / vae.config.scaling_factor
110
  x = vae.decode(z).sample