jbilcke-hf's picture
jbilcke-hf HF Staff
Initial commit with LFS-tracked binary files
01c0e76
raw
history blame
3.43 kB
import torch
from pathlib import Path
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from ..constants import VAE_PATH, PRECISION_TO_TYPE
def load_vae(vae_type,
vae_precision=None,
sample_size=None,
vae_path=None,
logger=None,
device=None
):
"""
Load and configure a Variational Autoencoder (VAE) model.
This function handles loading 3D causal VAE models, including configuration,
weight loading, precision setting, and device placement. It ensures the model
is properly initialized for inference.
Parameters:
vae_type (str): Type identifier for the VAE, must follow '???-*' format for 3D VAEs
vae_precision (str, optional): Desired precision type (e.g., 'fp16', 'fp32').
Uses model's default if not specified.
sample_size (tuple, optional): Input sample dimensions to override config defaults
vae_path (str, optional): Path to VAE model files. Uses predefined path from
VAE_PATH constant if not specified.
logger (logging.Logger, optional): Logger instance for progress/debug messages
device (torch.device, optional): Target device to place the model (e.g., 'cuda' or 'cpu')
Returns:
tuple: Contains:
- vae (AutoencoderKLCausal3D): Loaded and configured VAE model
- vae_path (str): Actual path used to load the VAE
- spatial_compression_ratio (int): Spatial dimension compression factor
- time_compression_ratio (int): Temporal dimension compression factor
Raises:
ValueError: If vae_type does not follow the required 3D VAE format '???-*'
"""
if vae_path is None:
vae_path = VAE_PATH[vae_type]
vae_compress_spec, _, _ = vae_type.split("-")
length = len(vae_compress_spec)
# Process 3D VAE (valid format with 3-character compression spec)
if length == 3:
if logger is not None:
logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
config = AutoencoderKLCausal3D.load_config(vae_path)
if sample_size:
vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
else:
vae = AutoencoderKLCausal3D.from_config(config)
ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device)
if "state_dict" in ckpt:
ckpt = ckpt["state_dict"]
vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
vae.load_state_dict(vae_ckpt)
spatial_compression_ratio = vae.config.spatial_compression_ratio
time_compression_ratio = vae.config.time_compression_ratio
else:
raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.")
if vae_precision is not None:
vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
vae.requires_grad_(False)
if logger is not None:
logger.info(f"VAE to dtype: {vae.dtype}")
if device is not None:
vae = vae.to(device)
# Ensure model is in evaluation mode (disables dropout/batch norm training behavior)
# Note: Even with dropout rate 0, eval mode is recommended for consistent inference
vae.eval()
return vae, vae_path, spatial_compression_ratio, time_compression_ratio