aai / old2 /old /src /tasks /images /init_sys.py
mantrakp
Refactor flux_helpers.py to enable or disable Vae
37112ef
raw
history blame
4.67 kB
import spaces
import torch
from diffusers import (
DiffusionPipeline,
AutoencoderKL,
FluxControlNetModel,
FluxMultiControlNetModel,
ControlNetModel,
AutoPipelineForText2Image,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import hf_hub_download
from transformers import CLIPFeatureExtractor
from photomaker import FaceAnalysis2
# Initialize System
def load_sd():
    """Load every model and helper object used by the image tasks.

    The function picks ``"cuda"`` when ``torch.cuda.is_available()`` and
    ``"cpu"`` otherwise, then loads, in order: the base text-to-image
    pipelines, the SDXL VAE and refiner, the safety checker with its
    feature extractor, the ControlNet models, the face detector and the
    PhotoMaker V2 checkpoint file.

    Returns:
        A 9-tuple ``(device, models, sdxl_vae, refiner, safety_checker,
        feature_extractor, controlnet_models, face_detector,
        photomaker_ckpt)``.

        * ``models`` is a list of dicts with the keys ``repo_id``,
          ``loader``, ``compute_type`` and the loaded ``pipeline``.
        * ``controlnet_models`` is a list of dicts with the keys
          ``repo_id``, ``name``, ``layers``, ``loader``, ``compute_type``
          and, for the ``"xl"`` and ``"flux-multi"`` loaders, the loaded
          ``controlnet``.
        * ``photomaker_ckpt`` is the value returned by
          ``hf_hub_download`` for ``photomaker-v2.bin``.

    Raises:
        Exception: whatever the underlying loaders raise when a model
            cannot be loaded, including when the no-variant retry of a
            base pipeline also fails.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Models
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        },
        {
            "repo_id": "SG161222/RealVisXL_V4.0",
            "loader": "xl",
            "compute_type": torch.float16,
        },
    ]

    for model in models:
        # Arguments shared by the first attempt and the fallback.
        load_kwargs = {
            "torch_dtype": model["compute_type"],
            "safety_checker": None,
        }
        # NOTE(review): the pipeline is moved to `device` and model CPU
        # offload is then enabled on it -- confirm both steps are wanted.
        try:
            # First try the "fp16" weight variant of the repository.
            pipeline = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"],
                variant="fp16",
                **load_kwargs,
            ).to(device)
            pipeline.enable_model_cpu_offload()
        except Exception:
            # Retry without a variant. `Exception` (rather than a bare
            # `except`) lets KeyboardInterrupt / SystemExit abort the load
            # instead of triggering a second download.
            pipeline = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"],
                **load_kwargs,
            ).to(device)
            pipeline.enable_model_cpu_offload()
        model["pipeline"] = pipeline

    # VAE n Refiner
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    ).to(device)
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(device)
    refiner.enable_model_cpu_offload()

    # Safety Checker
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32",
        from_pt=True,
    )

    # Controlnets
    controlnet_models = [
        {
            "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
            "name": "depth_xl",
            "layers": ["depth"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
            "name": "canny_xl",
            "layers": ["canny"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
            "name": "openpose_xl",
            "layers": ["pose"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
            "name": "scribble_xl",
            "layers": ["scribble"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
            "name": "flux1_union_pro",
            "layers": [
                "canny_fl",
                "tile_fl",
                "depth_fl",
                "blur_fl",
                "pose_fl",
                "gray_fl",
                "low_quality_fl",
            ],
            "loader": "flux-multi",
            "compute_type": torch.bfloat16,
        },
    ]

    for controlnet in controlnet_models:
        if controlnet["loader"] == "xl":
            controlnet["controlnet"] = ControlNetModel.from_pretrained(
                controlnet["repo_id"],
                torch_dtype=controlnet["compute_type"],
            ).to(device)
        elif controlnet["loader"] == "flux-multi":
            # The single Flux ControlNet is wrapped in the multi-model
            # container.
            flux_controlnet = FluxControlNetModel.from_pretrained(
                controlnet["repo_id"],
                torch_dtype=controlnet["compute_type"],
            ).to(device)
            controlnet["controlnet"] = FluxMultiControlNetModel([flux_controlnet])
        # TODO: Add support for flux only controlnet

    # Face Detection (for PhotoMaker)
    # NOTE(review): only the CUDA execution provider is requested, even when
    # `device` resolved to "cpu" -- confirm this is intended.
    face_detector = FaceAnalysis2(
        providers=['CUDAExecutionProvider'],
        allowed_modules=['detection', 'recognition'],
    )
    face_detector.prepare(ctx_id=0, det_size=(640, 640))

    # PhotoMaker V2 (for SDXL only)
    photomaker_ckpt = hf_hub_download(
        repo_id="TencentARC/PhotoMaker-V2",
        filename="photomaker-v2.bin",
        repo_type="model",
    )

    return (
        device,
        models,
        sdxl_vae,
        refiner,
        safety_checker,
        feature_extractor,
        controlnet_models,
        face_detector,
        photomaker_ckpt,
    )
# Run the loader once at import time and bind its nine results to
# module-level names, in the order `load_sd` returns them.
(
    device,
    models,
    sdxl_vae,
    refiner,
    safety_checker,
    feature_extractor,
    controlnet_models,
    face_detector,
    photomaker_ckpt,
) = load_sd()