aai / modules /pipelines /common_pipelines.py
mantrakp
Refactor ControlNetReq class to remove unused import and add controlnets, control_images, and controlnet_conditioning_scale attributes
daf9c75
raw
history blame
631 Bytes
import torch
from diffusers import (
DiffusionPipeline,
AutoencoderKL,
)
from diffusers.schedulers import *
def load_common():
    """Load the SDXL refiner pipeline together with the VAE it uses.

    The VAE is loaded from ``madebyollin/sdxl-vae-fp16-fix`` and handed to the
    ``stabilityai/stable-diffusion-xl-refiner-1.0`` pipeline; both are loaded
    with ``torch.float16`` weights.

    Returns:
        tuple: ``(refiner, sdxl_vae)`` -- the refiner ``DiffusionPipeline``
        and the ``AutoencoderKL`` instance passed to it.
    """
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"

    # VAE and refiner
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    ).to(device)
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )

    if has_cuda:
        # Model CPU offload manages device placement itself (see the diffusers
        # memory-optimization docs), so the pipeline is deliberately NOT moved
        # to the GPU first: doing so only costs an extra full transfer.
        refiner.enable_model_cpu_offload()
    else:
        # Without an accelerator there is nothing to offload from; just keep
        # the pipeline on the CPU instead of installing GPU offload hooks.
        refiner.to(device)

    return refiner, sdxl_vae
# Runs at import time: importing this module calls load_common() once and
# binds the returned pipeline and VAE to the module-level names `refiner`
# and `sdxl_vae`.
refiner, sdxl_vae = load_common()