Spaces:
Running on Zero
Running on Zero
| """Abstract base class for all model types in LightDiffusion-Next. | |
| This module defines the contract that all model implementations must follow, | |
| enabling a clean, pluggable architecture where SD15, SDXL, and other models | |
| can be used interchangeably. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field | |
| from typing import TYPE_CHECKING, Any, Callable, Optional | |
| import torch | |
| if TYPE_CHECKING: | |
| from src.Core.Context import Context | |
| class ModelCapabilities: | |
| """Describes what a model implementation can do. | |
| This allows the pipeline to adapt its behavior based on | |
| the loaded model's capabilities. | |
| """ | |
| # Resolution constraints | |
| min_resolution: int = 256 | |
| max_resolution: int = 2048 | |
| preferred_resolution: int = 512 | |
| requires_resolution_multiple: int = 64 | |
| # Feature support | |
| supports_hires_fix: bool = True | |
| supports_img2img: bool = True | |
| supports_inpainting: bool = False | |
| supports_controlnet: bool = False | |
| supports_lora: bool = True # LoRA compatibility | |
| # Performance hints | |
| supports_stable_fast: bool = True | |
| supports_deepcache: bool = True | |
| supports_tome: bool = True | |
| # Model-specific flags | |
| uses_dual_clip: bool = False # SDXL uses dual CLIP | |
| requires_size_conditioning: bool = False # SDXL needs size embeddings | |
| is_flux: bool = False | |
| is_flux2: bool = False | |
| def validate_resolution(self, width: int, height: int) -> tuple[int, int]: | |
| """Validate and adjust resolution to meet model requirements. | |
| Args: | |
| width: Requested width | |
| height: Requested height | |
| Returns: | |
| Adjusted (width, height) tuple | |
| """ | |
| # Maintain aspect ratio when clamping to max_resolution | |
| if width > self.max_resolution or height > self.max_resolution: | |
| scale = min(self.max_resolution / width, self.max_resolution / height) | |
| width = int(width * scale) | |
| height = int(height * scale) | |
| # Clamp to minimum | |
| width = max(self.min_resolution, width) | |
| height = max(self.min_resolution, height) | |
| # Round to required multiple | |
| width = (width // self.requires_resolution_multiple) * self.requires_resolution_multiple | |
| height = (height // self.requires_resolution_multiple) * self.requires_resolution_multiple | |
| return width, height | |
| class AbstractModel(ABC): | |
| """Abstract base class defining the contract for all model implementations. | |
| Every model type (SD15, SDXL, FLUX, etc.) must implement these methods | |
| to work with the modular pipeline. | |
| """ | |
| def __init__(self, model_path: str = None): | |
| """Initialize the model. | |
| Args: | |
| model_path: Optional path to the model checkpoint | |
| """ | |
| self.model_path = model_path | |
| self.model = None | |
| self.clip = None | |
| self.vae = None | |
| self._loaded = False | |
| self._capabilities: Optional[ModelCapabilities] = None | |
| def capabilities(self) -> ModelCapabilities: | |
| """Return the model's capabilities. | |
| Subclasses should override _create_capabilities() to customize. | |
| """ | |
| if self._capabilities is None: | |
| self._capabilities = self._create_capabilities() | |
| return self._capabilities | |
| def _create_capabilities(self) -> ModelCapabilities: | |
| """Create and return the capabilities for this model type. | |
| Returns: | |
| ModelCapabilities instance describing this model's features | |
| """ | |
| pass | |
| def is_loaded(self) -> bool: | |
| """Check if the model is currently loaded.""" | |
| return self._loaded | |
| def load(self, model_path: str = None) -> "AbstractModel": | |
| """Load the model from disk. | |
| Args: | |
| model_path: Optional override for the model path | |
| Returns: | |
| Self for method chaining | |
| """ | |
| pass | |
| def encode_prompt( | |
| self, | |
| prompt: str | list[str], | |
| negative_prompt: str | list[str] = "", | |
| clip_skip: int = -2, | |
| ) -> tuple[Any, Any]: | |
| """Encode text prompts into conditioning tensors. | |
| Args: | |
| prompt: Positive prompt(s) to encode | |
| negative_prompt: Negative prompt(s) to encode | |
| clip_skip: Number of CLIP layers to skip from the end | |
| Returns: | |
| Tuple of (positive_conditioning, negative_conditioning) | |
| """ | |
| pass | |
| def generate( | |
| self, | |
| ctx: "Context", | |
| positive: Any, | |
| negative: Any, | |
| latent_image: Optional[Any] = None, | |
| start_step: Optional[int] = None, | |
| last_step: Optional[int] = None, | |
| disable_noise: bool = False, | |
| callback: Optional[Callable] = None, | |
| ) -> dict: | |
| """Generate latents using the sampler. | |
| This is the core generation method that runs the diffusion process. | |
| Args: | |
| ctx: Pipeline context containing all generation parameters | |
| positive: Positive conditioning from encode_prompt | |
| negative: Negative conditioning from encode_prompt | |
| Returns: | |
| Dictionary containing 'samples' key with generated latents | |
| """ | |
| pass | |
| def decode(self, latents: torch.Tensor) -> torch.Tensor: | |
| """Decode latents to pixel space. | |
| Args: | |
| latents: Latent tensor to decode | |
| Returns: | |
| Decoded image tensor in [0, 1] range | |
| """ | |
| pass | |
| def set_vae_autotune(self, enabled: bool) -> None: | |
| """Update the loaded VAE autotune preference if the model exposes one.""" | |
| if self.vae is not None and hasattr(self.vae, "set_autotune_enabled"): | |
| self.vae.set_autotune_enabled(enabled) | |
| def apply_lora( | |
| self, | |
| lora_name: str, | |
| strength_model: float = 1.0, | |
| strength_clip: float = 1.0, | |
| ) -> "AbstractModel": | |
| """Apply a LoRA to the model. | |
| Default implementation attempts to use the standard LoRA loader. | |
| Subclasses can override for model-specific behavior. | |
| Args: | |
| lora_name: Name/path of the LoRA file | |
| strength_model: Strength to apply to the model | |
| strength_clip: Strength to apply to CLIP | |
| Returns: | |
| Self for method chaining | |
| """ | |
| if not self._loaded: | |
| raise RuntimeError("Model must be loaded before applying LoRA") | |
| try: | |
| from src.Model import LoRas | |
| loader = LoRas.LoraLoader() | |
| result = loader.load_lora( | |
| lora_name=lora_name, | |
| strength_model=strength_model, | |
| strength_clip=strength_clip, | |
| model=self.model, | |
| clip=self.clip, | |
| ) | |
| self.model = result[0] | |
| self.clip = result[1] | |
| except Exception as e: | |
| import logging | |
| logging.getLogger(__name__).warning(f"Failed to apply LoRA {lora_name}: {e}") | |
| return self | |
| def apply_fp8(self) -> "AbstractModel": | |
| """Apply FP8 quantization to the diffusion model weights. | |
| Hardware-gated: only applies on supported GPUs (Ada Lovelace 8.9+, Hopper 9.0+). | |
| Reduces memory usage by ~50% vs FP16 with minimal quality impact. | |
| After casting weights to FP8, enables comfy_cast_weights on all affected | |
| modules so that forward() uses cast_bias_weight() to upcast FP8 weights | |
| to the input dtype at runtime, preventing dtype mismatch errors. | |
| Returns: | |
| Self for method chaining | |
| """ | |
| if not self._loaded: | |
| raise RuntimeError("Model must be loaded before applying FP8") | |
| try: | |
| from src.Device import Device | |
| from src.cond.cast import CastWeightBiasOp | |
| if not Device.is_fp8_supported(): | |
| import logging | |
| logging.getLogger(__name__).info( | |
| "FP8 not supported on this GPU (requires compute capability 8.9+), skipping" | |
| ) | |
| return self | |
| inner = getattr(self.model, 'model', self.model) | |
| # Try common diffusion submodule names, otherwise fall back to top-level module | |
| diff_model = getattr(inner, 'diffusion_model', None) | |
| if diff_model is None: | |
| import torch.nn as nn | |
| if isinstance(inner, nn.Module): | |
| diff_model = inner | |
| import logging | |
| logging.getLogger(__name__).info( | |
| "No 'diffusion_model' submodule found; using top-level model for FP8 quantization" | |
| ) | |
| else: | |
| import logging | |
| logging.getLogger(__name__).warning("No diffusion_model found for FP8 quantization") | |
| return self | |
| converted = 0 | |
| cast_enabled = 0 | |
| for name, module in diff_model.named_modules(): | |
| # Quantize weight parameters to FP8 | |
| if hasattr(module, 'weight') and module.weight is not None: | |
| w = module.weight | |
| if w.dtype in (torch.float16, torch.bfloat16, torch.float32) and w.ndim >= 2: | |
| module.weight.data = Device.cast_to_fp8(w.data) | |
| converted += 1 | |
| # Enable runtime casting so forward() upcasts FP8→input dtype | |
| if isinstance(module, CastWeightBiasOp): | |
| module.comfy_cast_weights = True | |
| cast_enabled += 1 | |
| import logging | |
| logging.getLogger(__name__).info( | |
| f"FP8 quantization applied to {converted} weight tensors, " | |
| f"runtime casting enabled on {cast_enabled} modules" | |
| ) | |
| except Exception as e: | |
| import logging | |
| logging.getLogger(__name__).warning(f"FP8 quantization failed: {e}") | |
| return self | |
| def apply_nvfp4(self) -> "AbstractModel": | |
| """Apply NVFP4 (4-bit) quantization to the diffusion model weights. | |
| Reduces memory usage by ~75% vs FP16 with some quality impact. | |
| After quantizing weights to NVFP4, enables comfy_cast_weights on all affected | |
| modules so that forward() uses cast_bias_weight() to dequantize NVFP4 weights | |
| to the input dtype at runtime. | |
| Returns: | |
| Self for method chaining | |
| """ | |
| if not self._loaded: | |
| raise RuntimeError("Model must be loaded before applying NVFP4") | |
| try: | |
| from src.cond.cast import CastWeightBiasOp | |
| from src.Utilities.Quantization import quantize_nvfp4 | |
| inner = getattr(self.model, 'model', self.model) | |
| diff_model = getattr(inner, 'diffusion_model', None) | |
| if diff_model is None: | |
| import torch.nn as nn | |
| if isinstance(inner, nn.Module): | |
| diff_model = inner | |
| else: | |
| import logging | |
| logging.getLogger(__name__).warning("No diffusion_model found for NVFP4 quantization") | |
| return self | |
| converted = 0 | |
| cast_enabled = 0 | |
| for name, module in diff_model.named_modules(): | |
| # Quantize weight parameters to NVFP4 | |
| if hasattr(module, 'weight') and module.weight is not None: | |
| w = module.weight | |
| if w.dtype in (torch.float16, torch.bfloat16, torch.float32) and w.ndim == 2 and w.numel() > 4096: | |
| from src.Utilities.Quantization import quantize_nvfp4, from_blocked | |
| q_weight, tensor_scale, blocked_scales = quantize_nvfp4(w.data) | |
| module.weight = torch.nn.Parameter(q_weight, requires_grad=False) | |
| module.quant_format = "nvfp4" | |
| # Pre-de-block scales to save compute during inference | |
| rows, cols = w.shape | |
| block_cols = (cols + 15) // 16 | |
| deblocked_scales = from_blocked(blocked_scales, rows, block_cols) | |
| import torch.nn as nn | |
| if isinstance(module, nn.Module): | |
| module.register_buffer("weight_scale_2", tensor_scale) | |
| module.register_buffer("weight_scale", deblocked_scales) | |
| else: | |
| module.weight_scale_2 = tensor_scale | |
| module.weight_scale = deblocked_scales | |
| module.original_shape = w.shape | |
| converted += 1 | |
| # Enable runtime casting so forward() dequantizes NVFP4→input dtype | |
| if isinstance(module, CastWeightBiasOp): | |
| module.comfy_cast_weights = True | |
| cast_enabled += 1 | |
| import logging | |
| logging.getLogger(__name__).info( | |
| f"NVFP4 quantization applied to {converted} weight tensors, " | |
| f"runtime casting enabled on {cast_enabled} modules" | |
| ) | |
| except Exception as e: | |
| import logging | |
| logging.getLogger(__name__).exception(f"NVFP4 quantization failed: {e}") | |
| return self | |
| def apply_torch_compile(self, mode: str = "max-autotune-no-cudagraphs") -> "AbstractModel": | |
| """Apply torch.compile optimization to the model. | |
| Uses 'max-autotune-no-cudagraphs' by default to get autotuning benefits | |
| without CUDA graph fragility (which causes assertion errors with dynamic | |
| model state like LoRA patches and mixed dtypes). | |
| Args: | |
| mode: Compilation mode - 'max-autotune-no-cudagraphs' (recommended), | |
| 'max-autotune', 'default', or 'reduce-overhead' | |
| Returns: | |
| Self for method chaining | |
| """ | |
| if not self._loaded: | |
| raise RuntimeError("Model must be loaded before applying torch.compile") | |
| try: | |
| from src.Device import Device | |
| if not hasattr(torch, 'compile'): | |
| import logging | |
| logging.getLogger(__name__).warning("torch.compile requires PyTorch 2.0+, skipping") | |
| return self | |
| Device.enable_torch_compile(True) | |
| inner = getattr(self.model, 'model', self.model) | |
| # Try to find a diffusion submodule; if missing, fall back to compiling the top-level module | |
| diff_model = getattr(inner, 'diffusion_model', None) | |
| if diff_model is None: | |
| import torch.nn as nn | |
| if isinstance(inner, nn.Module): | |
| # Compile the top-level module for models without a diffusion wrapper (Flux2, etc.) | |
| compiled = Device.compile_model(inner, mode=mode) | |
| if compiled is not inner: | |
| # If compile returns a Module we can safely replace the module. | |
| try: | |
| import torch.nn as _nn | |
| if isinstance(compiled, _nn.Module): | |
| if hasattr(self.model, 'model'): | |
| self.model.model = compiled | |
| else: | |
| self.model = compiled | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile applied to top-level model (mode={mode})") | |
| elif callable(compiled): | |
| # Preserve the original module instance but attach the compiled | |
| # callable to its forward method so attribute access (e.g. latent_format) | |
| # continues to work while runtime calls go through the compiled code. | |
| try: | |
| import types | |
| # attach compiled function to the inner module so forward calls use it | |
| setattr(inner, '_compiled_fn', compiled) | |
| def _compiled_forward(self, *args, **kwargs): | |
| return self._compiled_fn(*args, **kwargs) | |
| inner.forward = types.MethodType(_compiled_forward, inner) | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile returned callable; attached compiled forward to top-level module (mode={mode})") | |
| except Exception: | |
| import logging | |
| logging.getLogger(__name__).warning("Failed to attach compiled callable to module.forward; leaving original module intact") | |
| else: | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile returned unexpected type {type(compiled)}; leaving original model intact") | |
| except Exception: | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile returned a new object but could not reassign it; compiled object is available (mode={mode})") | |
| else: | |
| import logging | |
| logging.getLogger(__name__).warning("No diffusion_model found for torch.compile") | |
| else: | |
| compiled = Device.compile_model(diff_model, mode=mode) | |
| if compiled is not diff_model: | |
| # If compiled returned an nn.Module, replace the diffusion_model. | |
| import torch.nn as _nn | |
| if isinstance(compiled, _nn.Module): | |
| inner.diffusion_model = compiled | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile applied to diffusion model (mode={mode})") | |
| elif callable(compiled): | |
| # Attach compiled callable to the diffusion_model.forward so callers | |
| # (e.g. model.apply_model) continue to operate with the same | |
| # argument mapping while using compiled execution. | |
| try: | |
| import types | |
| if hasattr(inner, 'diffusion_model'): | |
| dm = inner.diffusion_model | |
| setattr(dm, '_compiled_fn', compiled) | |
| def _compiled_forward(self, *args, **kwargs): | |
| return self._compiled_fn(*args, **kwargs) | |
| dm.forward = types.MethodType(_compiled_forward, dm) | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile returned callable for diffusion_model; attached compiled forward (mode={mode})") | |
| else: | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile returned callable but no diffusion_model to attach to; compiled available (mode={mode})") | |
| except Exception: | |
| import logging | |
| logging.getLogger(__name__).warning("Failed to attach compiled callable to diffusion_model.forward") | |
| else: | |
| import logging | |
| logging.getLogger(__name__).info(f"torch.compile returned unexpected type {type(compiled)} for diffusion_model; leaving original module intact") | |
| except Exception as e: | |
| import logging | |
| logging.getLogger(__name__).warning(f"torch.compile optimization failed: {e}") | |
| return self | |
| def apply_stable_fast(self, enable_cuda_graph: bool = True) -> "AbstractModel": | |
| """Apply StableFast optimization to the model. | |
| Args: | |
| enable_cuda_graph: Whether to enable CUDA graphs | |
| Returns: | |
| Self for method chaining | |
| """ | |
| if not self._loaded: | |
| raise RuntimeError("Model must be loaded before applying StableFast") | |
| if not self.capabilities.supports_stable_fast: | |
| import logging | |
| logging.getLogger(__name__).warning( | |
| f"Model does not support StableFast, skipping" | |
| ) | |
| return self | |
| try: | |
| from src.StableFast import StableFast | |
| applier = StableFast.ApplyStableFastUnet() | |
| result = applier.apply_stable_fast( | |
| enable_cuda_graph=enable_cuda_graph, | |
| model=self.model, | |
| ) | |
| self.model = result[0] | |
| except Exception as e: | |
| import logging | |
| logging.getLogger(__name__).warning(f"StableFast optimization failed: {e}") | |
| return self | |
| def apply_deepcache( | |
| self, | |
| cache_interval: int = 3, | |
| cache_depth: int = 2, | |
| start_step: int = 0, | |
| end_step: int = 1000, | |
| ) -> "AbstractModel": | |
| """Apply DeepCache optimization to the model. | |
| Args: | |
| cache_interval: Steps between cache updates | |
| cache_depth: U-Net depth for caching | |
| start_step: Start applying at this timestep | |
| end_step: Stop applying at this timestep | |
| Returns: | |
| Self for method chaining | |
| """ | |
| if not self._loaded: | |
| raise RuntimeError("Model must be loaded before applying DeepCache") | |
| if not self.capabilities.supports_deepcache: | |
| import logging | |
| logging.getLogger(__name__).warning( | |
| f"Model does not support DeepCache, skipping" | |
| ) | |
| return self | |
| try: | |
| from src.WaveSpeed import deepcache_nodes | |
| deepcache = deepcache_nodes.ApplyDeepCacheOnModel() | |
| # DeepCache returns a tuple | |
| result = deepcache.patch( | |
| model=(self.model,), | |
| object_to_patch="diffusion_model", | |
| cache_interval=cache_interval, | |
| cache_depth=cache_depth, | |
| start_step=start_step, | |
| end_step=end_step, | |
| ) | |
| if isinstance(result, tuple) and len(result) > 0: | |
| self.model = result[0] | |
| except Exception as e: | |
| import logging | |
| logging.getLogger(__name__).warning(f"DeepCache optimization failed: {e}") | |
| return self | |
| def apply_hidiff(self, model_type: str = "auto") -> "AbstractModel": | |
| """Apply HiDiffusion MSW-MSA attention optimization. | |
| Args: | |
| model_type: Model type hint ('auto', 'sd15', 'sdxl') | |
| Returns: | |
| Self for method chaining | |
| """ | |
| if not self._loaded: | |
| raise RuntimeError("Model must be loaded before applying HiDiffusion") | |
| try: | |
| from src.hidiffusion import msw_msa_attention | |
| optimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple() | |
| result = optimizer.go(model_type=model_type, model=self.model) | |
| self.model = result[0] | |
| except Exception as e: | |
| import logging | |
| logging.getLogger(__name__).warning(f"HiDiffusion optimization failed: {e}") | |
| return self | |
| def unload(self) -> None: | |
| """Release model resources and free GPU memory.""" | |
| self.model = None | |
| self.clip = None | |
| self.vae = None | |
| self._loaded = False | |
| # Force garbage collection to release tensor references | |
| import gc | |
| gc.collect() | |
| # Attempt to free GPU memory | |
| try: | |
| from src.Device import Device | |
| Device.soft_empty_cache(force=True) | |
| except Exception: | |
| pass | |
| def __enter__(self) -> "AbstractModel": | |
| """Context manager entry - load the model.""" | |
| if not self._loaded: | |
| self.load() | |
| return self | |
| def __exit__(self, exc_type, exc_val, exc_tb) -> None: | |
| """Context manager exit - optionally unload the model.""" | |
| # By default we don't unload on context exit to support caching | |
| # Subclasses can override if they want different behavior | |
| pass | |