|  | import torch | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D | 
					
						
						|  | from ..constants import VAE_PATH, PRECISION_TO_TYPE | 
					
						
						|  |  | 
					
						
						|  | def load_vae(vae_type, | 
					
						
						|  | vae_precision=None, | 
					
						
						|  | sample_size=None, | 
					
						
						|  | vae_path=None, | 
					
						
						|  | logger=None, | 
					
						
						|  | device=None | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Load and configure a Variational Autoencoder (VAE) model. | 
					
						
						|  |  | 
					
						
						|  | This function handles loading 3D causal VAE models, including configuration, | 
					
						
						|  | weight loading, precision setting, and device placement. It ensures the model | 
					
						
						|  | is properly initialized for inference. | 
					
						
						|  |  | 
					
						
						|  | Parameters: | 
					
						
						|  | vae_type (str): Type identifier for the VAE, must follow '???-*' format for 3D VAEs | 
					
						
						|  | vae_precision (str, optional): Desired precision type (e.g., 'fp16', 'fp32'). | 
					
						
						|  | Uses model's default if not specified. | 
					
						
						|  | sample_size (tuple, optional): Input sample dimensions to override config defaults | 
					
						
						|  | vae_path (str, optional): Path to VAE model files. Uses predefined path from | 
					
						
						|  | VAE_PATH constant if not specified. | 
					
						
						|  | logger (logging.Logger, optional): Logger instance for progress/debug messages | 
					
						
						|  | device (torch.device, optional): Target device to place the model (e.g., 'cuda' or 'cpu') | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | tuple: Contains: | 
					
						
						|  | - vae (AutoencoderKLCausal3D): Loaded and configured VAE model | 
					
						
						|  | - vae_path (str): Actual path used to load the VAE | 
					
						
						|  | - spatial_compression_ratio (int): Spatial dimension compression factor | 
					
						
						|  | - time_compression_ratio (int): Temporal dimension compression factor | 
					
						
						|  |  | 
					
						
						|  | Raises: | 
					
						
						|  | ValueError: If vae_type does not follow the required 3D VAE format '???-*' | 
					
						
						|  | """ | 
					
						
						|  | if vae_path is None: | 
					
						
						|  | vae_path = VAE_PATH[vae_type] | 
					
						
						|  | vae_compress_spec, _, _ = vae_type.split("-") | 
					
						
						|  | length = len(vae_compress_spec) | 
					
						
						|  |  | 
					
						
						|  | if length == 3: | 
					
						
						|  | if logger is not None: | 
					
						
						|  | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") | 
					
						
						|  | config = AutoencoderKLCausal3D.load_config(vae_path) | 
					
						
						|  | if sample_size: | 
					
						
						|  | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) | 
					
						
						|  | else: | 
					
						
						|  | vae = AutoencoderKLCausal3D.from_config(config) | 
					
						
						|  | ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device) | 
					
						
						|  | if "state_dict" in ckpt: | 
					
						
						|  | ckpt = ckpt["state_dict"] | 
					
						
						|  | vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} | 
					
						
						|  | vae.load_state_dict(vae_ckpt) | 
					
						
						|  |  | 
					
						
						|  | spatial_compression_ratio = vae.config.spatial_compression_ratio | 
					
						
						|  | time_compression_ratio = vae.config.time_compression_ratio | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.") | 
					
						
						|  |  | 
					
						
						|  | if vae_precision is not None: | 
					
						
						|  | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision]) | 
					
						
						|  |  | 
					
						
						|  | vae.requires_grad_(False) | 
					
						
						|  |  | 
					
						
						|  | if logger is not None: | 
					
						
						|  | logger.info(f"VAE to dtype: {vae.dtype}") | 
					
						
						|  |  | 
					
						
						|  | if device is not None: | 
					
						
						|  | vae = vae.to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | vae.eval() | 
					
						
						|  |  | 
					
						
						|  | return vae, vae_path, spatial_compression_ratio, time_compression_ratio | 
					
						
						|  |  |