import logging

import spaces
import torch
from diffusers import (
    DiffusionPipeline,
    AutoencoderKL,
    FluxControlNetModel,
    FluxMultiControlNetModel,
    ControlNetModel,
    AutoPipelineForText2Image,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import hf_hub_download
from transformers import CLIPFeatureExtractor
from photomaker import FaceAnalysis2

_logger = logging.getLogger(__name__)


def _load_text2image_pipeline(model, device, **extra_kwargs):
    """Load one text-to-image pipeline, move it to *device* and enable offload.

    Args:
        model: Dict with at least ``"repo_id"`` and ``"compute_type"`` keys.
        device: Target device string passed to ``.to()``.
        **extra_kwargs: Extra keyword arguments forwarded to
            ``AutoPipelineForText2Image.from_pretrained`` (e.g. ``variant``).

    Returns:
        The loaded pipeline with model CPU offload enabled.
    """
    # NOTE(review): the diffusers docs advise against moving a pipeline to
    # CUDA before enable_model_cpu_offload(); the original ordering is kept
    # here on purpose -- confirm whether the explicit .to(device) is needed.
    pipeline = AutoPipelineForText2Image.from_pretrained(
        model["repo_id"],
        torch_dtype=model["compute_type"],
        safety_checker=None,
        **extra_kwargs,
    ).to(device)
    pipeline.enable_model_cpu_offload()
    return pipeline


# Initialize System
def load_sd():
    """Load every model the app needs and return them as a tuple.

    Loads the base text-to-image pipelines, the SDXL VAE and refiner, the
    safety checker with its feature extractor, the ControlNet models, the
    face detector and the PhotoMaker V2 checkpoint.

    Returns:
        A 9-tuple ``(device, models, sdxl_vae, refiner, safety_checker,
        feature_extractor, controlnet_models, face_detector,
        photomaker_ckpt)``. Each entry of ``models`` gains a ``"pipeline"``
        key; each entry of ``controlnet_models`` whose ``"loader"`` is
        ``"xl"`` or ``"flux-multi"`` gains a ``"controlnet"`` key.

    Raises:
        Exception: Whatever the underlying loaders raise when a model cannot
            be loaded even after the non-variant fallback.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Models
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        },
        {
            "repo_id": "SG161222/RealVisXL_V4.0",
            "loader": "xl",
            "compute_type": torch.float16,
        },
    ]

    for model in models:
        try:
            # Prefer the fp16 weight variant when the repository ships one.
            model["pipeline"] = _load_text2image_pipeline(
                model, device, variant="fp16"
            )
        except Exception:
            # Best-effort fallback: retry with the default weights, but keep a
            # record of why the fp16 attempt failed instead of hiding it.
            # (A bare ``except:`` would also trap KeyboardInterrupt/SystemExit.)
            _logger.warning(
                "Loading %s with variant='fp16' failed; retrying without a variant.",
                model["repo_id"],
                exc_info=True,
            )
            model["pipeline"] = _load_text2image_pipeline(model, device)

    # VAE n Refiner
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    ).to(device)
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(device)
    refiner.enable_model_cpu_offload()

    # Safety Checker
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32",
        from_pt=True,
    )

    # Controlnets
    controlnet_models = [
        {
            "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
            "name": "depth_xl",
            "layers": ["depth"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
            "name": "canny_xl",
            "layers": ["canny"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
            "name": "openpose_xl",
            "layers": ["pose"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
            "name": "scribble_xl",
            "layers": ["scribble"],
            "loader": "xl",
            "compute_type": torch.float16,
        },
        {
            "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
            "name": "flux1_union_pro",
            "layers": [
                "canny_fl",
                "tile_fl",
                "depth_fl",
                "blur_fl",
                "pose_fl",
                "gray_fl",
                "low_quality_fl",
            ],
            "loader": "flux-multi",
            "compute_type": torch.bfloat16,
        },
    ]

    for controlnet in controlnet_models:
        if controlnet["loader"] == "xl":
            controlnet["controlnet"] = ControlNetModel.from_pretrained(
                controlnet["repo_id"],
                torch_dtype=controlnet["compute_type"],
            ).to(device)
        elif controlnet["loader"] == "flux-multi":
            # The single Flux ControlNet is wrapped in a one-element multi model.
            controlnet["controlnet"] = FluxMultiControlNetModel(
                [
                    FluxControlNetModel.from_pretrained(
                        controlnet["repo_id"],
                        torch_dtype=controlnet["compute_type"],
                    ).to(device)
                ]
            )
        # TODO: Add support for flux only controlnet

    # Face Detection (for PhotoMaker)
    # NOTE(review): the CUDA execution provider is requested even when
    # ``device`` resolved to "cpu" -- confirm the CPU-only behaviour.
    face_detector = FaceAnalysis2(
        providers=['CUDAExecutionProvider'],
        allowed_modules=['detection', 'recognition'],
    )
    face_detector.prepare(ctx_id=0, det_size=(640, 640))

    # PhotoMaker V2 (for SDXL only)
    photomaker_ckpt = hf_hub_download(
        repo_id="TencentARC/PhotoMaker-V2",
        filename="photomaker-v2.bin",
        repo_type="model",
    )

    return (
        device,
        models,
        sdxl_vae,
        refiner,
        safety_checker,
        feature_extractor,
        controlnet_models,
        face_detector,
        photomaker_ckpt,
    )


# Everything is loaded once, at import time, and exposed as module globals.
(
    device,
    models,
    sdxl_vae,
    refiner,
    safety_checker,
    feature_extractor,
    controlnet_models,
    face_detector,
    photomaker_ckpt,
) = load_sd()