# aai/modules/pipelines/flux_pipelines.py
# Author: mantrakp
# Commit daf9c75: "Refactor ControlNetReq class to remove unused import and add
# controlnets, control_images, and controlnet_conditioning_scale attributes"
# (file size: 1.87 kB)
import torch
from diffusers import (
AutoPipelineForText2Image,
DiffusionPipeline,
AutoencoderKL,
FluxControlNetModel,
FluxMultiControlNetModel,
)
from diffusers.schedulers import *
def load_flux():
    """Load the FLUX.1-dev text-to-image pipeline, a standalone VAE and a ControlNet.

    On CUDA machines the pipeline uses diffusers model CPU offload. On CPU-only
    machines it is simply kept on the CPU, because offload requires an accelerator.

    Returns:
        tuple: ``(device, models, flux_vae, controlnet)`` where

        * ``device`` is ``"cuda"`` if ``torch.cuda.is_available()`` else ``"cpu"``;
        * ``models`` is the list of model spec dicts. Each dict gains a
          ``"pipeline"`` key holding its loaded text-to-image pipeline;
        * ``flux_vae`` is a separately loaded FLUX.1-dev ``AutoencoderKL`` in
          bfloat16, placed on ``device``;
        * ``controlnet`` is a ``FluxMultiControlNetModel`` wrapping the
          Shakker-Labs FLUX.1-dev ControlNet Union-Pro model (bfloat16) on ``device``.
    """
    import logging

    logger = logging.getLogger(__name__)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Models
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        }
    ]

    for model in models:
        # Load the VAE once, outside the retry below, so a failed first attempt
        # does not leave an orphaned copy of it allocated.
        vae = AutoencoderKL.from_pretrained(
            model["repo_id"], subfolder="vae", torch_dtype=model["compute_type"]
        )
        pipeline_kwargs = {
            "vae": vae,
            "torch_dtype": model["compute_type"],
            "safety_checker": None,
        }
        try:
            # Prefer the repo's "fp16" weight variant when it publishes one.
            pipeline = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"], variant="fp16", **pipeline_kwargs
            )
        except Exception as err:
            # Presumably the repo has no fp16 variant; retry with default weights.
            logger.info(
                "Loading %s with variant='fp16' failed (%s); using default weights",
                model["repo_id"],
                err,
            )
            pipeline = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"], **pipeline_kwargs
            )

        if device == "cuda":
            # Model CPU offload manages device placement itself and moves
            # sub-models onto the GPU on demand. The pipeline is deliberately
            # not moved to CUDA first, which would stage the whole model on
            # the GPU for nothing.
            pipeline.enable_model_cpu_offload()
        else:
            # CPU offload needs an accelerator; on CPU keep everything local.
            pipeline.to(device)
        model["pipeline"] = pipeline

    # VAE n Refiner
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)

    # ControlNet
    controlnet = FluxMultiControlNetModel(
        [
            FluxControlNetModel.from_pretrained(
                "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
                torch_dtype=torch.bfloat16,
            ).to(device)
        ]
    )
    return device, models, flux_vae, controlnet
# Load every model eagerly at import time and expose the results as module
# attributes (device, models, flux_vae, controlnet) for other modules to import.
# NOTE(review): importing this module therefore triggers the full model
# download/allocation performed by load_flux().
device, models, flux_vae, controlnet = load_flux()