|
import spaces |
|
import torch |
|
from diffusers import ( |
|
DiffusionPipeline, |
|
AutoencoderKL, |
|
FluxControlNetModel, |
|
FluxMultiControlNetModel, |
|
ControlNetModel, |
|
AutoPipelineForText2Image, |
|
) |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker |
|
from huggingface_hub import hf_hub_download |
|
from transformers import CLIPFeatureExtractor |
|
from photomaker import FaceAnalysis2 |
|
|
|
|
|
|
|
def _load_text2image_pipeline(repo_id, compute_type, device):
    """Load a text-to-image pipeline, preferring its ``fp16`` weight variant.

    The first attempt asks for ``variant="fp16"``. If that load raises, the
    failure is logged and the load is retried once without a variant, so
    repositories that do not publish fp16 weights still work.

    Args:
        repo_id: Hub repository id handed to ``from_pretrained``.
        compute_type: Torch dtype passed as ``torch_dtype``.
        device: Device string the pipeline is moved to ("cuda" or "cpu").

    Returns:
        The loaded pipeline, moved to ``device``; model CPU offload is
        enabled when ``device`` is "cuda".
    """
    import logging

    shared_kwargs = {"torch_dtype": compute_type, "safety_checker": None}

    # Only the variant-specific load is guarded: a failure while moving the
    # pipeline to the device would not be fixed by reloading other weights.
    try:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            repo_id,
            variant="fp16",
            **shared_kwargs,
        )
    except Exception as err:
        # Best-effort fallback. ``Exception`` (unlike a bare ``except``) lets
        # KeyboardInterrupt / SystemExit propagate.
        logging.getLogger(__name__).warning(
            "Loading %s with variant='fp16' failed (%s); retrying without a variant.",
            repo_id,
            err,
        )
        pipeline = AutoPipelineForText2Image.from_pretrained(repo_id, **shared_kwargs)

    pipeline = pipeline.to(device)
    # Model CPU offload shuttles sub-models to an accelerator on demand, so it
    # is only meaningful when CUDA was selected.
    if device == "cuda":
        pipeline.enable_model_cpu_offload()
    return pipeline


def load_sd():
    """Load all pipelines, ControlNets and auxiliary models, and return them.

    Returns:
        A 9-tuple, in this order:

        * ``device``: "cuda" when ``torch.cuda.is_available()``, else "cpu".
        * ``models``: list of dicts with keys ``repo_id``, ``loader``,
          ``compute_type`` and the loaded ``pipeline``.
        * ``sdxl_vae``: the ``madebyollin/sdxl-vae-fp16-fix`` autoencoder.
        * ``refiner``: the SDXL refiner pipeline built on ``sdxl_vae``.
        * ``safety_checker``: the Stable Diffusion safety checker model.
        * ``feature_extractor``: the CLIP feature extractor.
        * ``controlnet_models``: list of dicts with keys ``repo_id``,
          ``name``, ``layers``, ``loader``, ``compute_type`` and the loaded
          ``controlnet`` (set only for the "xl" and "flux-multi" loaders).
        * ``face_detector``: a prepared ``FaceAnalysis2`` instance.
        * ``photomaker_ckpt``: the value returned by ``hf_hub_download`` for
          the PhotoMaker V2 checkpoint.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Base text-to-image pipelines -------------------------------------
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        },
        {
            "repo_id": "SG161222/RealVisXL_V4.0",
            "loader": "xl",
            "compute_type": torch.float16,
        },
    ]

    for model in models:
        model["pipeline"] = _load_text2image_pipeline(
            model["repo_id"],
            model["compute_type"],
            device,
        )

    # --- SDXL VAE and refiner ---------------------------------------------
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    ).to(device)
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(device)
    # Same accelerator-only guard as for the base pipelines.
    if device == "cuda":
        refiner.enable_model_cpu_offload()

    # --- Safety checker ---------------------------------------------------
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32",
        from_pt=True,
    )

    # --- ControlNets ------------------------------------------------------
    controlnet_models = [
        {
            "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
            "name": "depth_xl",
            "layers": ["depth"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
            "name": "canny_xl",
            "layers": ["canny"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
            "name": "openpose_xl",
            "layers": ["pose"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
            "name": "scribble_xl",
            "layers": ["scribble"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
            "name": "flux1_union_pro",
            "layers": ["canny_fl", "tile_fl", "depth_fl", "blur_fl", "pose_fl", "gray_fl", "low_quality_fl"],
            "loader": "flux-multi",
            "compute_type": torch.bfloat16,
        },
    ]

    for controlnet in controlnet_models:
        if controlnet["loader"] == "xl":
            controlnet["controlnet"] = ControlNetModel.from_pretrained(
                controlnet["repo_id"],
                torch_dtype=controlnet["compute_type"],
            ).to(device)
        elif controlnet["loader"] == "flux-multi":
            # The single FLUX ControlNet is wrapped in the multi-ControlNet
            # container.
            flux_controlnet = FluxControlNetModel.from_pretrained(
                controlnet["repo_id"],
                torch_dtype=controlnet["compute_type"],
            ).to(device)
            controlnet["controlnet"] = FluxMultiControlNetModel([flux_controlnet])

    # --- Face analysis ----------------------------------------------------
    # NOTE(review): provider and ctx_id are hard-coded for CUDA regardless of
    # `device`; confirm whether a CPU-only deployment needs to be supported.
    face_detector = FaceAnalysis2(
        providers=['CUDAExecutionProvider'],
        allowed_modules=['detection', 'recognition'],
    )
    face_detector.prepare(ctx_id=0, det_size=(640, 640))

    # --- PhotoMaker checkpoint --------------------------------------------
    photomaker_ckpt = hf_hub_download(
        repo_id="TencentARC/PhotoMaker-V2",
        filename="photomaker-v2.bin",
        repo_type="model",
    )

    return (
        device,
        models,
        sdxl_vae,
        refiner,
        safety_checker,
        feature_extractor,
        controlnet_models,
        face_detector,
        photomaker_ckpt,
    )
|
|
|
# Run the loader once at import time and publish its results as module-level
# names, in the same order load_sd() returns them.
(
    device,
    models,
    sdxl_vae,
    refiner,
    safety_checker,
    feature_extractor,
    controlnet_models,
    face_detector,
    photomaker_ckpt,
) = load_sd()
|
|