File size: 4,666 Bytes
5a59c13
9b84ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import spaces
import torch
from diffusers import (
    DiffusionPipeline,
    AutoencoderKL,
    FluxControlNetModel,
    FluxMultiControlNetModel,
    ControlNetModel,
    AutoPipelineForText2Image,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import hf_hub_download
from transformers import CLIPFeatureExtractor
from photomaker import FaceAnalysis2


# Initialize System
def _load_text2image_pipeline(repo_id, compute_type):
    """Load an ``AutoPipelineForText2Image`` from ``repo_id``, preferring fp16 weights.

    The ``variant="fp16"`` weights are tried first. If that load raises
    ``OSError`` or ``ValueError`` (e.g. the repo publishes no fp16 variant),
    the default weights are loaded instead. Only the download/load is guarded,
    so unrelated failures (OOM, interrupts) are not mistaken for a missing
    variant. The pipeline is returned unplaced; see ``_place_pipeline``.
    """
    load_kwargs = {"torch_dtype": compute_type, "safety_checker": None}
    try:
        return AutoPipelineForText2Image.from_pretrained(repo_id, variant="fp16", **load_kwargs)
    except (OSError, ValueError):
        # NOTE(review): presumably no fp16 variant exists for this repo; fall back to default weights.
        return AutoPipelineForText2Image.from_pretrained(repo_id, **load_kwargs)


def _place_pipeline(pipeline, device):
    """Attach ``pipeline`` to ``device`` and return it.

    On CUDA, model CPU offload is enabled. Otherwise the pipeline is simply
    moved to ``device``, because recent diffusers releases raise from
    ``enable_model_cpu_offload`` when no accelerator is available.
    """
    if device == "cuda":
        # No .to("cuda") first: offload manages placement itself (and moves the
        # pipeline back to CPU), so a prior full move would only spike VRAM.
        pipeline.enable_model_cpu_offload()
    else:
        pipeline.to(device)
    return pipeline


def load_sd():
    """Load every model used by the app and return them as one tuple.

    Fetches checkpoints via ``from_pretrained`` / ``hf_hub_download``.

    Returns:
        tuple: ``(device, models, sdxl_vae, refiner, safety_checker,
        feature_extractor, controlnet_models, face_detector, photomaker_ckpt)``
        where ``device`` is ``"cuda"`` or ``"cpu"``; ``models`` and
        ``controlnet_models`` are the config lists below with a loaded
        ``"pipeline"`` / ``"controlnet"`` entry added to each dict; and
        ``photomaker_ckpt`` is the value returned by ``hf_hub_download``.

    Raises:
        ValueError: if a ControlNet config entry names an unknown ``"loader"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"

    # Models
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        },
        {
            "repo_id": "SG161222/RealVisXL_V4.0",
            "loader": "xl",
            "compute_type": torch.float16,
        },
    ]

    for model in models:
        pipeline = _load_text2image_pipeline(model["repo_id"], model["compute_type"])
        model["pipeline"] = _place_pipeline(pipeline, device)

    # VAE and refiner: the VAE is a refiner component, so it is placed together
    # with the refiner rather than moved on its own first.
    sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    refiner = _place_pipeline(refiner, device)

    # Safety checker
    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)

    # Controlnets
    controlnet_models = [
        {
            "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
            "name": "depth_xl",
            "layers": ["depth"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
            "name": "canny_xl",
            "layers": ["canny"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
            "name": "openpose_xl",
            "layers": ["pose"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
            "name": "scribble_xl",
            "layers": ["scribble"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
            "name": "flux1_union_pro",
            "layers": ["canny_fl", "tile_fl", "depth_fl", "blur_fl", "pose_fl", "gray_fl", "low_quality_fl"],
            "loader": "flux-multi",
            "compute_type": torch.bfloat16,
        },
    ]

    for controlnet in controlnet_models:
        loader = controlnet["loader"]
        if loader == "xl":
            controlnet["controlnet"] = ControlNetModel.from_pretrained(
                controlnet["repo_id"],
                torch_dtype=controlnet["compute_type"],
            ).to(device)
        elif loader == "flux-multi":
            # A single Flux ControlNet checkpoint wrapped in FluxMultiControlNetModel.
            flux_controlnet = FluxControlNetModel.from_pretrained(
                controlnet["repo_id"],
                torch_dtype=controlnet["compute_type"],
            ).to(device)
            controlnet["controlnet"] = FluxMultiControlNetModel([flux_controlnet])
        else:
            # TODO: add support for a single (non-multi) Flux ControlNet loader.
            # Fail loudly instead of leaving this entry without a "controlnet" key.
            raise ValueError(f"Unknown controlnet loader: {loader!r}")

    # Face detection (for PhotoMaker). Without CUDA, request the CPU provider;
    # NOTE(review): insightface treats a negative ctx_id as CPU — confirm for FaceAnalysis2.
    if use_cuda:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    face_detector = FaceAnalysis2(providers=providers, allowed_modules=["detection", "recognition"])
    face_detector.prepare(ctx_id=0 if use_cuda else -1, det_size=(640, 640))

    # PhotoMaker V2 (for SDXL only)
    photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker-V2", filename="photomaker-v2.bin", repo_type="model")

    return device, models, sdxl_vae, refiner, safety_checker, feature_extractor, controlnet_models, face_detector, photomaker_ckpt

# Load every model once, at import time, and expose the results as module-level names.
(
    device,
    models,
    sdxl_vae,
    refiner,
    safety_checker,
    feature_extractor,
    controlnet_models,
    face_detector,
    photomaker_ckpt,
) = load_sd()