|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D |
|
from ..constants import VAE_PATH, PRECISION_TO_TYPE |
|
|
|
def load_vae(vae_type: str="884-16c-hy", |
|
vae_precision: str=None, |
|
sample_size: tuple=None, |
|
vae_path: str=None, |
|
logger=None, |
|
device=None |
|
): |
|
"""the fucntion to load the 3D VAE model |
|
|
|
Args: |
|
vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy". |
|
vae_precision (str, optional): the precision to load vae. Defaults to None. |
|
sample_size (tuple, optional): the tiling size. Defaults to None. |
|
vae_path (str, optional): the path to vae. Defaults to None. |
|
logger (_type_, optional): logger. Defaults to None. |
|
device (_type_, optional): device to load vae. Defaults to None. |
|
""" |
|
if vae_path is None: |
|
vae_path = VAE_PATH[vae_type] |
|
|
|
if logger is not None: |
|
logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") |
|
config = AutoencoderKLCausal3D.load_config(vae_path) |
|
if sample_size: |
|
vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) |
|
else: |
|
vae = AutoencoderKLCausal3D.from_config(config) |
|
|
|
vae_ckpt = Path(vae_path) / "pytorch_model.pt" |
|
assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}" |
|
|
|
ckpt = torch.load(vae_ckpt, map_location=vae.device) |
|
if "state_dict" in ckpt: |
|
ckpt = ckpt["state_dict"] |
|
if any(k.startswith("vae.") for k in ckpt.keys()): |
|
ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} |
|
vae.load_state_dict(ckpt) |
|
|
|
spatial_compression_ratio = vae.config.spatial_compression_ratio |
|
time_compression_ratio = vae.config.time_compression_ratio |
|
|
|
if vae_precision is not None: |
|
vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision]) |
|
|
|
vae.requires_grad_(False) |
|
|
|
if logger is not None: |
|
logger.info(f"VAE to dtype: {vae.dtype}") |
|
|
|
if device is not None: |
|
vae = vae.to(device) |
|
|
|
vae.eval() |
|
|
|
return vae, vae_path, spatial_compression_ratio, time_compression_ratio |
|
|