import torch
from diffusers import AutoencoderKL

# image_size passed to AutoencoderKL.from_single_file for each supported
# model family; the keys double as the set of valid `version` values.
_SINGLE_FILE_IMAGE_SIZES = {"v1": 512, "v2": 768, "xl": 1024}


def get_vae(version, file_path=None, fp16=False):
    """Load a VAE from a local checkpoint file or a default Hugging Face repo.

    Args:
        version: Model family, one of ``"v1"``, ``"v2"`` or ``"xl"``.
        file_path: Optional path to a single-file checkpoint. When truthy it
            takes precedence over the default hub repository.
        fp16: Request half-precision weights. Only honoured when loading
            from the hub; it is ignored when ``file_path`` is given.

    Returns:
        The ``AutoencoderKL`` instance produced by ``from_single_file`` or
        ``from_pretrained``.

    Raises:
        SystemExit: With status 1, after showing a prompt, when ``version``
            is not one of the supported values.
    """
    if version not in _SINGLE_FILE_IMAGE_SIZES:
        try:
            input("Invalid VAE version. Press any key to exit")
        except EOFError:
            # No interactive stdin (closed or exhausted pipe): skip the pause
            # and still terminate with the intended exit status.
            pass
        # `exit()` is injected by the `site` module and may be missing
        # (python -S, frozen apps); raising SystemExit is equivalent.
        raise SystemExit(1)

    if file_path:
        # Local checkpoints are loaded without a dtype override, so `fp16`
        # has no effect on this path.
        return AutoencoderKL.from_single_file(
            file_path,
            image_size=_SINGLE_FILE_IMAGE_SIZES[version],
        )

    if version == "xl":
        if fp16:
            # Half-precision SDXL uses a dedicated repo rather than the
            # base model's "vae" subfolder.
            return AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix",
                torch_dtype=torch.float16,
            )
        return AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae",
        )

    if version == "v1":
        repo_id = "runwayml/stable-diffusion-v1-5"
    else:
        repo_id = "stabilityai/stable-diffusion-2-1"
    dtype = torch.float16 if fp16 else torch.float32
    return AutoencoderKL.from_pretrained(
        repo_id,
        subfolder="vae",
        torch_dtype=dtype,
    )