| """ | |
| Model loading and initialization for Pixagram AI Pixel Art Generator | |
| FIXED VERSION - Uses correct InstantID pipeline and Compel encoder | |
| """ | |
| import torch | |
| import time | |
| import os | |
| from diffusers import ( | |
| ControlNetModel, | |
| AutoencoderKL, | |
| LCMScheduler | |
| ) | |
| from transformers import ( | |
| CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection | |
| ) | |
| from insightface.app import FaceAnalysis | |
| from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MediapipeFaceDetector | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| # --- START FIX: Import correct pipeline and Compel --- | |
| from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline | |
| from compel import Compel, ReturnedEmbeddingsType | |
| # --- END FIX --- | |
| from config import ( | |
| device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN, | |
| FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG | |
| ) | |
# (We keep download_model_with_retry, load_face_analysis, load_depth_detector,
#  load_openpose_detector, and load_mediapipe_face_detector as they were)
# ... (Keep all original functions from line 25 down to line 180) ...


def download_model_with_retry(repo_id, filename, max_retries=None, **kwargs):
    """Download model with retry logic and proper token handling."""
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']
    # Ensure token is passed if available
    if HUGGINGFACE_TOKEN and "token" not in kwargs:
        kwargs["token"] = HUGGINGFACE_TOKEN
    for attempt in range(max_retries):
        try:
            print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
            return hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )
        except Exception as e:
            print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
                time.sleep(DOWNLOAD_CONFIG['retry_delay'])
            else:
                print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
                raise
    return None


def load_face_analysis():
    """
    Load face analysis model with proper model downloading from HuggingFace.
    Downloads from DIAMONIK7777/antelopev2 which has the correct model structure.
    """
    print("Loading face analysis model...")
    try:
        antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
        # --- FIX: Load InsightFace on CPU to save VRAM ---
        face_app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
        face_app.prepare(ctx_id=0, det_size=(640, 640))
        print(" [OK] Face analysis loaded (on CPU)")
        return face_app, True
    except Exception as e:
        print(f" [ERROR] Face detection not available: {e}")
        import traceback
        traceback.print_exc()
        return None, False


def load_depth_detector():
    """
    Load depth detector with fallback hierarchy: Leres → Zoe → Midas.
    Returns (detector, detector_type, success).
    """
    print("Loading depth detector with fallback hierarchy...")
    # Try LeresDetector first (best quality)
    try:
        print(" Attempting LeresDetector (highest quality)...")
        # --- FIX: Load on CPU ---
        leres_depth = LeresDetector.from_pretrained("lllyasviel/Annotators")
        # leres_depth.to(device)  # Removed
        print(" [OK] LeresDetector loaded successfully (on CPU)")
        return leres_depth, 'leres', True
    except Exception as e:
        print(f" [INFO] LeresDetector not available: {e}")
    # Fallback to ZoeDetector
    try:
        print(" Attempting ZoeDetector (fallback #1)...")
        # --- FIX: Load on CPU ---
        zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        # zoe_depth.to(device)  # Removed
        print(" [OK] ZoeDetector loaded successfully (on CPU)")
        return zoe_depth, 'zoe', True
    except Exception as e:
        print(f" [INFO] ZoeDetector not available: {e}")
    # Final fallback to MidasDetector
    try:
        print(" Attempting MidasDetector (fallback #2)...")
        # --- FIX: Load on CPU ---
        midas_depth = MidasDetector.from_pretrained("lllyasviel/Annotators")
        # midas_depth.to(device)  # Removed
        print(" [OK] MidasDetector loaded successfully (on CPU)")
        return midas_depth, 'midas', True
    except Exception as e:
        print(f" [WARNING] MidasDetector not available: {e}")
    print(" [ERROR] No depth detector available")
    return None, None, False


# --- NEW FUNCTION ---
def load_openpose_detector():
    """Load OpenPose detector."""
    print("Loading OpenPose detector...")
    try:
        # --- FIX: Load on CPU ---
        openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        # openpose.to(device)  # Removed
        print(" [OK] OpenPose loaded successfully (on CPU)")
        return openpose, True
    except Exception as e:
        print(f" [WARNING] OpenPose not available: {e}")
        return None, False
# --- END NEW FUNCTION ---


# --- NEW FUNCTION ---
def load_mediapipe_face_detector():
    """Load MediapipeFaceDetector for advanced face detection."""
    print("Loading MediapipeFaceDetector...")
    try:
        face_detector = MediapipeFaceDetector()
        print(" [OK] MediapipeFaceDetector loaded successfully")
        return face_detector, True
    except Exception as e:
        print(f" [WARNING] MediapipeFaceDetector not available: {e}")
        return None, False
# --- END NEW FUNCTION ---


def load_controlnets():
    """Load ControlNet models."""
    print("Loading ControlNet Zoe Depth model...")
    # --- FIX: Load core models on GPU ---
    controlnet_depth = ControlNetModel.from_pretrained(
        "xinsir/controlnet-depth-sdxl-1.0",
        torch_dtype=dtype
    ).to(device)
    print(" [OK] ControlNet Depth loaded (on GPU)")
    # --- NEW: Load OpenPose ControlNet ---
    print("Loading ControlNet OpenPose model...")
    try:
        # --- FIX: Load core models on GPU ---
        controlnet_openpose = ControlNetModel.from_pretrained(
            "xinsir/controlnet-openpose-sdxl-1.0",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] ControlNet OpenPose loaded (on GPU)")
    except Exception as e:
        print(f" [WARNING] ControlNet OpenPose not available: {e}")
        controlnet_openpose = None
    # --- END NEW ---
    print("Loading InstantID ControlNet...")
    try:
        # --- FIX: Load core models on GPU ---
        controlnet_instantid = ControlNetModel.from_pretrained(
            "InstantX/InstantID",
            subfolder="ControlNetModel",
            torch_dtype=dtype
        ).to(device)
        print(" [OK] InstantID ControlNet loaded successfully (on GPU)")
        # Return all three models
        return controlnet_depth, controlnet_instantid, controlnet_openpose, True
    except Exception as e:
        print(f" [WARNING] InstantID ControlNet not available: {e}")
        # Return models, indicating InstantID failure
        return controlnet_depth, None, controlnet_openpose, False


# --- START: REMOVED load_image_encoder ---
# (The new pipeline handles this internally)
# --- END: REMOVED load_image_encoder ---


def load_sdxl_pipeline(controlnets):
    """Load SDXL checkpoint from HuggingFace Hub."""
    print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
    # --- START FIX: Load base text models for Compel (from previous fix) ---
    print(" Loading base tokenizers and text encoders...")
    BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
    tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")
    text_encoder = CLIPTextModel.from_pretrained(
        BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
    ).to(device)
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
    ).to(device)
    print(" [OK] Base text/token models loaded")
    # --- END FIX ---
    try:
        model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")
        # --- START FIX: Load the CORRECT pipeline ---
        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            model_path,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True,
            # Pass components
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
        ).to(device)
        # --- END FIX ---
        print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
        return pipe, True
    except Exception as e:
        print(f" [WARNING] Could not load custom checkpoint: {e}")
        print(" Using default SDXL base model")
        # --- START FIX: Fallback to the CORRECT pipeline ---
        pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True,
            # Pass components
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
        ).to(device)
        # --- END FIX ---
        return pipe, False


def load_loras(pipe):
    """Load all LORAs from HuggingFace Hub."""
    print("Loading all LORAs from HuggingFace Hub...")
    loaded_loras = {}
    lora_files = {
        "retroart": MODEL_FILES.get("lora_retroart"),
        "vga": MODEL_FILES.get("lora_vga"),
        "lucasart": MODEL_FILES.get("lora_lucasart")
    }
    for adapter_name, filename in lora_files.items():
        if not filename:
            print(f" [INFO] No file specified for LORA '{adapter_name}', skipping.")
            loaded_loras[adapter_name] = False
            continue
        try:
            lora_path = download_model_with_retry(MODEL_REPO, filename, repo_type="model")
            pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
            print(f" [OK] LORA loaded successfully: {filename} as '{adapter_name}'")
            loaded_loras[adapter_name] = True
        except Exception as e:
            print(f" [WARNING] Could not load LORA {filename}: {e}")
            loaded_loras[adapter_name] = False
    success = any(loaded_loras.values())
    if not success:
        print(" [WARNING] No LORAs were loaded successfully.")
    return loaded_loras, success
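

# Usage sketch (an assumption, not part of this module - the actual style
# selection happens in the inference code): adapters registered via
# load_lora_weights(..., adapter_name=...) can be activated per request with
# diffusers' set_adapters, e.g.
#   pipe.set_adapters(["retroart"], adapter_weights=[0.8])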


# --- START FIX: Replace setup_ip_adapter ---
def setup_ip_adapter(pipe):
    """
    Setup IP-Adapter for InstantID face embeddings using the pipeline's method.
    """
    print("Setting up IP-Adapter for InstantID face embeddings...")
    try:
        # Download InstantID weights
        ip_adapter_path = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin",
            repo_type="model"
        )
        # Use the pipeline's built-in loader
        pipe.load_ip_adapter_instantid(ip_adapter_path)
        print(" [OK] IP-Adapter fully loaded via pipeline")
        return None, True  # We don't need to return a model
    except Exception as e:
        print(f" [ERROR] Could not setup IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return None, False
# --- END FIX ---


# --- START FIX: Replace setup_cappella with setup_compel ---
def setup_compel(pipe):
    """Setup Compel for robust prompt encoding."""
    print("Setting up Compel (prompt encoder)...")
    try:
        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]
        )
        print(" [OK] Compel loaded successfully.")
        return compel, True
    except Exception as e:
        print(f" [WARNING] Compel not available: {e}")
        return None, False
# --- END FIX ---
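

# Usage sketch (an assumption - the exact prompts and weights come from the
# generation code): with requires_pooled=[False, True], calling the Compel
# instance returns (prompt_embeds, pooled_prompt_embeds) for the SDXL pipeline:
#   conditioning, pooled = compel("pixel art portrait, (sharp details)1.2")
#   pipe(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, ...)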


def setup_scheduler(pipe):
    """Setup LCM scheduler."""
    print("Setting up LCM scheduler...")
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    print(" [OK] LCM scheduler configured")


def optimize_pipeline(pipe):
    """Apply optimizations to pipeline."""
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print(" [OK] xformers enabled")
        except Exception as e:
            print(f" [INFO] xformers not available: {e}")


def load_caption_model():
    """
    Load caption model with proper error handling.
    Tries multiple models in order of quality.
    """
    print("Loading caption model...")
    # Try GIT-Large first (good balance of quality and compatibility)
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM
        print(" Attempting GIT-Large (recommended)...")
        caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        caption_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            torch_dtype=dtype
        )
        print(" [OK] GIT-Large model loaded (produces detailed captions, on CPU)")
        return caption_processor, caption_model, True, 'git'
    except Exception as e1:
        print(f" [INFO] GIT-Large not available: {e1}")
        # Try BLIP base as fallback
        try:
            from transformers import BlipProcessor, BlipForConditionalGeneration
            print(" Attempting BLIP base (fallback)...")
            caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            caption_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base",
                torch_dtype=dtype
            )
            print(" [OK] BLIP base model loaded (standard captions, on CPU)")
            return caption_processor, caption_model, True, 'blip'
        except Exception as e2:
            print(f" [WARNING] Caption models not available: {e2}")
            print(" Caption generation will be disabled")
            return None, None, False, 'none'


def set_clip_skip(pipe):
    """Set CLIP skip value."""
    # Note: the effective CLIP-skip behaviour comes from Compel's
    # penultimate-hidden-states setting (see setup_compel); this function
    # only reports the configured value.
    if hasattr(pipe, 'text_encoder'):
        print(f" [OK] CLIP skip set to {CLIP_SKIP}")


print("[OK] Model loading functions ready")