File size: 1,868 Bytes
			
			| 42ae52a daf9c75 42ae52a daf9c75 42ae52a daf9c75 42ae52a daf9c75 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | 
import torch
from diffusers import (
    AutoPipelineForText2Image,
    DiffusionPipeline,
    AutoencoderKL,
    FluxControlNetModel,
    FluxMultiControlNetModel,
)
from diffusers.schedulers import *
def _load_pipeline(model, vae):
    """Load the text-to-image pipeline described by *model*.

    Tries the ``fp16`` weight variant first and retries with the repository's
    default weights if that load raises ``OSError`` or ``ValueError``.

    Args:
        model: Entry from ``load_flux``'s model table; reads ``"repo_id"`` and
            ``"compute_type"``.
        vae: VAE module passed through to the pipeline as its ``vae``
            component (shared by both load attempts, so it is loaded once).

    Returns:
        The pipeline returned by ``AutoPipelineForText2Image.from_pretrained``.
    """
    common_kwargs = {
        "vae": vae,
        "torch_dtype": model["compute_type"],
        "safety_checker": None,
    }
    try:
        return AutoPipelineForText2Image.from_pretrained(
            model["repo_id"], variant="fp16", **common_kwargs
        )
    except (OSError, ValueError):
        # Presumably the repo ships no fp16 variant (or its files cannot be
        # fetched); retry with the default weights. Only load/lookup errors
        # trigger the retry -- Ctrl-C, SystemExit and programming errors
        # propagate instead of being silently retried.
        return AutoPipelineForText2Image.from_pretrained(
            model["repo_id"], **common_kwargs
        )


def load_flux():
    """Load the FLUX.1-dev text-to-image pipeline, a standalone VAE and a ControlNet.

    Returns:
        A ``(device, models, flux_vae, controlnet)`` tuple:

        * ``device`` -- ``"cuda"`` when CUDA is available, else ``"cpu"``.
        * ``models`` -- list of dicts with ``"repo_id"``, ``"loader"`` and
          ``"compute_type"``; each gets its loaded pipeline stored under
          ``"pipeline"``. On CUDA the pipeline uses model CPU offload; on CPU
          it is simply placed on the CPU.
        * ``flux_vae`` -- a separate bfloat16 FLUX VAE moved to ``device``.
        * ``controlnet`` -- a ``FluxMultiControlNetModel`` wrapping the
          Shakker-Labs Union-Pro ControlNet (bfloat16) on ``device``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    flux_repo_id = "black-forest-labs/FLUX.1-dev"

    # Models
    models = [
        {
            "repo_id": flux_repo_id,
            "loader": "flux",
            "compute_type": torch.bfloat16,
        }
    ]
    for model in models:
        # Load the pipeline's VAE once, in the model's compute dtype, and
        # reuse it for both load attempts. It is not moved to `device` here:
        # placement is decided below together with the rest of the pipeline.
        vae = AutoencoderKL.from_pretrained(
            model["repo_id"], subfolder="vae", torch_dtype=model["compute_type"]
        )
        pipeline = _load_pipeline(model, vae)
        if device == "cuda":
            # Model CPU offload manages device placement itself. Calling
            # `.to("cuda")` first would require the entire pipeline to be
            # resident in GPU memory at once, defeating the point of offload.
            pipeline.enable_model_cpu_offload()
        else:
            # Offloading needs an accelerator to offload to; on CPU just
            # place the pipeline there.
            pipeline = pipeline.to(device)
        model["pipeline"] = pipeline

    # Standalone VAE for use outside the pipeline, placed on `device`.
    # NOTE(review): kept as an instance separate from the pipeline's VAE,
    # presumably so the offload hooks installed above do not govern its
    # placement -- confirm before sharing the two.
    flux_vae = AutoencoderKL.from_pretrained(
        flux_repo_id, subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)

    # ControlNet, wrapped in a multi-ControlNet container with one member.
    controlnet = FluxMultiControlNetModel(
        [
            FluxControlNetModel.from_pretrained(
                "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
                torch_dtype=torch.bfloat16,
            ).to(device)
        ]
    )
    return device, models, flux_vae, controlnet
# Load everything once at import time and expose the results as module globals.
(device, models, flux_vae, controlnet) = load_flux()
 |