mantrakp
Refactor ControlNetReq class to remove unused import and add controlnets, control_images, and controlnet_conditioning_scale attributes
daf9c75
import torch | |
from diffusers import ( | |
AutoPipelineForText2Image, | |
DiffusionPipeline, | |
AutoencoderKL, | |
FluxControlNetModel, | |
FluxMultiControlNetModel, | |
) | |
from diffusers.schedulers import * | |
def _load_pipeline(model):
    """Load the text-to-image pipeline described by the spec dict *model*.

    Reads ``model["repo_id"]`` and ``model["compute_type"]``. The ``fp16``
    weight variant is tried first; if diffusers cannot load that variant the
    repository's default weights are loaded instead. The pipeline is returned
    on the CPU (as loaded); device placement is left to the caller.

    Raises:
        Whatever ``from_pretrained`` raises if the default-weights load fails too.
    """
    repo_id = model["repo_id"]
    dtype = model["compute_type"]
    # Build the VAE once so it is not re-instantiated when the fp16 attempt fails.
    vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
    common_kwargs = dict(vae=vae, torch_dtype=dtype, safety_checker=None)
    try:
        return AutoPipelineForText2Image.from_pretrained(
            repo_id, variant="fp16", **common_kwargs
        )
    except (OSError, ValueError):
        # A missing weight variant surfaces as OSError (hub file lookup) or
        # ValueError (variant resolution); retry with the default weights.
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt.
        return AutoPipelineForText2Image.from_pretrained(repo_id, **common_kwargs)


def load_flux():
    """Load the FLUX.1-dev pipeline(s), a standalone VAE and the union ControlNet.

    Returns:
        tuple: ``(device, models, flux_vae, controlnet)`` where

        * ``device`` is ``"cuda"`` when ``torch.cuda.is_available()`` else ``"cpu"``;
        * ``models`` is a list of spec dicts, each gaining a ``"pipeline"`` key;
        * ``flux_vae`` is a separate ``AutoencoderKL`` instance moved to ``device``;
        * ``controlnet`` is a ``FluxMultiControlNetModel`` wrapping one
          ``FluxControlNetModel`` moved to ``device``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Models
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        }
    ]
    for model in models:
        pipeline = _load_pipeline(model)
        if device == "cuda":
            # Offloading manages device placement itself; moving the pipeline
            # to CUDA first would load every component onto the GPU at once.
            pipeline.enable_model_cpu_offload()
        else:
            # Model CPU offload needs an accelerator; keep the pipeline on CPU.
            pipeline.to(device)
        model["pipeline"] = pipeline

    # VAE n Refiner
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)

    # ControlNet
    controlnet = FluxMultiControlNetModel(
        [
            FluxControlNetModel.from_pretrained(
                "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
                torch_dtype=torch.bfloat16,
            ).to(device)
        ]
    )
    return device, models, flux_vae, controlnet
# Load every model once, at import time, and expose the results as module globals.
(device, models, flux_vae, controlnet) = load_flux()