""" Model handler for WAN-VACE video generation """ import torch # ----------------------------------------------------------------------------- # XPU shim for CPU‑only environments # # Newer versions of `diffusers` attempt to call `torch.xpu.empty_cache()` for # Intel GPU support. If the installed PyTorch build does not include XPU # support (as is the case on CPU‑only environments), accessing `torch.xpu` # results in an AttributeError. To avoid this, we define a dummy `xpu` # namespace on the `torch` module when it is missing. This namespace # implements the minimal methods used by `diffusers` (`empty_cache`, # `is_available`, and `device_count`). # # Intel’s `intel-extension-for-pytorch` provides XPU support, but even when # installed, some CPU builds of PyTorch may not expose `torch.xpu`. This # shim ensures that the application runs regardless of whether XPU support is # present. # ----------------------------------------------------------------------------- if not hasattr(torch, "xpu"): class _DummyXPU: @staticmethod def empty_cache(): return None @staticmethod def manual_seed(_seed: int): return None @staticmethod def is_available(): return False @staticmethod def device_count(): return 0 @staticmethod def current_device(): return 0 @staticmethod def set_device(_idx: int): return None torch.xpu = _DummyXPU() # type: ignore import time from typing import Optional, Tuple, Any from transformers import UMT5EncoderModel from diffusers import AutoencoderKLWan, WanVACEPipeline, WanVACETransformer3DModel, GGUFQuantizationConfig from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler from diffusers.utils import export_to_video from huggingface_hub import login import gradio as gr from config import MODEL_CONFIG, DEFAULT_PARAMS, HF_TOKEN import os from utils import create_temp_video_path, validate_generation_params, validate_prompt, format_generation_info class WanVACEModelHandler: """Handler for WAN-VACE model loading and video generation""" def __init__(self): self.pipe = None self.is_loaded = False self.loading_progress = 0 def login_hf(self) -> bool: """Login to Hugging Face""" try: login(token=HF_TOKEN) return True except Exception as e: print(f"Warning: Could not login to Hugging Face: {e}") return False def load_model(self, progress_callback=None) -> Tuple[bool, str]: """Load the WAN-VACE model components""" try: # Login to HF self.login_hf() if progress_callback: progress_callback(0.1, "Loading transformer model...") # Determine desired dtype for CPU/GPU execution. # Hugging Face Spaces often run on CPU, where bfloat16 may not be supported. # Allow the dtype to be configured via the WAN_DTYPE environment variable. # Supported values: "bfloat16" (default) or "float32". dtype_str = os.getenv("WAN_DTYPE", "bfloat16").lower() # Select compute dtype: use bfloat16 only if requested and available. # Fall back to float32 otherwise. compute_dtype = torch.bfloat16 if dtype_str == "bfloat16" else torch.float32 # Likewise for the torch dtype used when loading weights. 
            torch_dtype = compute_dtype

            # Load transformer
            transformer = WanVACETransformer3DModel.from_single_file(
                MODEL_CONFIG["transformer_path"],
                quantization_config=GGUFQuantizationConfig(compute_dtype=compute_dtype),
                torch_dtype=torch_dtype,
            )

            if progress_callback:
                progress_callback(0.4, "Loading text encoder...")

            # Load text encoder
            text_encoder = UMT5EncoderModel.from_pretrained(
                MODEL_CONFIG["text_encoder_path"],
                gguf_file=MODEL_CONFIG["text_encoder_file"],
                torch_dtype=torch_dtype,
            )

            if progress_callback:
                progress_callback(0.7, "Loading VAE...")

            # Load VAE
            vae = AutoencoderKLWan.from_pretrained(
                MODEL_CONFIG["vae_path"],
                subfolder="vae",
                torch_dtype=torch.float32,
            )

            if progress_callback:
                progress_callback(0.9, "Assembling pipeline...")

            # Create pipeline
            self.pipe = WanVACEPipeline.from_pretrained(
                MODEL_CONFIG["pipeline_path"],
                transformer=transformer,
                text_encoder=text_encoder,
                vae=vae,
                torch_dtype=torch_dtype,
            )

            # Configure scheduler
            flow_shift = DEFAULT_PARAMS["flow_shift"]
            self.pipe.scheduler = UniPCMultistepScheduler.from_config(
                self.pipe.scheduler.config, flow_shift=flow_shift
            )

            # Enable optimizations
            self.pipe.enable_model_cpu_offload()
            self.pipe.vae.enable_tiling()

            self.is_loaded = True

            if progress_callback:
                progress_callback(1.0, "Model loaded successfully!")

            return True, "Model loaded successfully!"

        except Exception as e:
            error_msg = f"Error loading model: {str(e)}"
            if progress_callback:
                progress_callback(0, error_msg)
            return False, error_msg

    def generate_video(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = DEFAULT_PARAMS["width"],
        height: int = DEFAULT_PARAMS["height"],
        num_frames: int = DEFAULT_PARAMS["num_frames"],
        num_inference_steps: int = DEFAULT_PARAMS["num_inference_steps"],
        guidance_scale: float = DEFAULT_PARAMS["guidance_scale"],
        seed: Optional[int] = None,
        progress_callback=None,
    ) -> Tuple[bool, str, str, str]:
        """
        Generate video from text prompt.

        Returns: (success, video_path, error_message, generation_info)
        """
        if not self.is_loaded:
            return False, "", "Model not loaded. Please load the model first.", ""

        # Validate inputs
        prompt_valid, prompt_error = validate_prompt(prompt)
        if not prompt_valid:
            return False, "", prompt_error or "Invalid prompt", ""

        params_valid, params_error = validate_generation_params(
            width, height, num_frames, num_inference_steps, guidance_scale
        )
        if not params_valid:
            return False, "", params_error or "Invalid parameters", ""

        try:
            if progress_callback:
                progress_callback(0.1, "Preparing generation...")

            # Check if pipeline is loaded
            if self.pipe is None:
                return False, "", "Pipeline not initialized. Please load the model first.", ""

            # Set up generator with seed
            generator = torch.Generator()
            if seed is not None:
                generator.manual_seed(seed)
            else:
                generator.manual_seed(0)  # Default seed

            if progress_callback:
                progress_callback(0.2, "Starting video generation...")

            start_time = time.time()

            # Generate video
            output = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                width=width,
                height=height,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                conditioning_scale=DEFAULT_PARAMS["conditioning_scale"],
                generator=generator,
            ).frames[0]

            if progress_callback:
                progress_callback(0.8, "Exporting video...")

            # Export to video file
            output_path = create_temp_video_path()
            export_to_video(output, output_path, fps=DEFAULT_PARAMS["fps"])

            generation_time = time.time() - start_time

            if progress_callback:
                progress_callback(1.0, "Video generation complete!")

            # Format generation info
            gen_info = format_generation_info(
                prompt,
                negative_prompt,
                width,
                height,
                num_frames,
                num_inference_steps,
                guidance_scale,
                generation_time,
            )

            return True, output_path, "", gen_info

        except Exception as e:
            error_msg = f"Error during video generation: {str(e)}"
            if progress_callback:
                progress_callback(0, error_msg)
            return False, "", error_msg, ""


# Global model handler instance
model_handler = WanVACEModelHandler()
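

# -----------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the application flow).
# It assumes the weights referenced in MODEL_CONFIG can be downloaded and that
# enough memory is available to hold the pipeline; the prompt and seed below
# are placeholders.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    ok, message = model_handler.load_model(
        progress_callback=lambda p, msg: print(f"[{p:.0%}] {msg}")
    )
    print(message)
    if ok:
        success, video_path, error, info = model_handler.generate_video(
            prompt="A red fox running through a snowy forest at dawn",
            seed=42,
        )
        print(video_path if success else error)
        if info:
            print(info)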