"""Optimized device and memory management for LightDiffusion-Next.
Performance optimizations from ComfyUI:
- Async CUDA streams for weight offloading
- Pinned memory for faster CPU-GPU transfers
- cuDNN benchmarking
- FP16 accumulation
"""
import logging
import platform
import sys
from enum import Enum
from typing import Optional, Union, Tuple
import psutil
import torch
# Enable TF32 on supported hardware for faster matrix ops
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
except Exception:
pass
# Enable cuDNN benchmarking for optimal convolution algorithms
try:
torch.backends.cudnn.benchmark = True
except Exception:
pass
# === SDPA Backend Priority (from ComfyUI for optimal attention on Windows) ===
# Set Flash Attention > Efficient > Math priority
SDPA_PRIORITY_SET = False
try:
if torch.cuda.is_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
SDPA_BACKEND_PRIORITY = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
# Add cuDNN attention if available (newest)
if hasattr(SDPBackend, 'CUDNN_ATTENTION'):
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
SDPA_PRIORITY_SET = True
logging.info(f"SDPA backend priority set: {[b.name for b in SDPA_BACKEND_PRIORITY]}")
except (ImportError, TypeError, AttributeError) as e:
logging.debug(f"Could not set SDPA backend priority: {e}")
def get_sdpa_context():
"""Get context manager for SDPA backend priority."""
if SDPA_PRIORITY_SET:
from torch.nn.attention import sdpa_kernel
return sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True)
else:
import contextlib
return contextlib.nullcontext()
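# Illustrative usage (a sketch, not executed at import): wrap attention calls in
# the priority context so PyTorch tries cuDNN/Flash attention first. `q`, `k`,
# and `v` below are assumed (batch, heads, seq, dim) tensors, not defined here.
#
#     with get_sdpa_context():
#         out = torch.nn.functional.scaled_dot_product_attention(q, k, v)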
class VRAMState(Enum):
DISABLED = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
SHARED = 5
class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
# Global state
vram_state = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU
directml_enabled = False
xpu_available = False
DISABLE_SMART_MEMORY = False
FORCE_FP32 = False
FORCE_FP16 = False
WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 if WINDOWS else 400 * 1024 * 1024
# Async offloading with CUDA streams (from ComfyUI)
NUM_STREAMS = 2 # Set to 2 for async offloading on Nvidia/AMD
STREAMS = {}
stream_counters = {}
# Pinned memory management (from ComfyUI)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1 # Will be set during initialization
# Detect hardware
try:
xpu_available = torch.xpu.is_available()
except Exception:
pass
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
except Exception:
pass
# Library availability
XFORMERS_IS_AVAILABLE = False
XFORMERS_ENABLED_VAE = True
SAGEATTENTION_IS_AVAILABLE = False
SAGEATTENTION_ENABLED_VAE = True
SPARGEATTN_IS_AVAILABLE = False
SPARGEATTN_ENABLED_VAE = True
ENABLE_PYTORCH_ATTENTION = False
VAE_DTYPE = torch.float32
try:
import xformers.ops
XFORMERS_IS_AVAILABLE = getattr(xformers, '_has_cpp_library', True)
v = getattr(xformers.version, '__version__', '')
if v.startswith("0.0.18"):
XFORMERS_ENABLED_VAE = False
logging.warning("xformers 0.0.18 has black image bugs")
except Exception:
pass
try:
import sageattention
SAGEATTENTION_IS_AVAILABLE = True
except Exception:
pass
try:
import spas_sage_attn
SPARGEATTN_IS_AVAILABLE = True
except Exception:
pass
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except AttributeError:
OOM_EXCEPTION = Exception
# === Async CUDA Stream Management (from ComfyUI for faster offloading) ===
def get_offload_stream(device: torch.device):
"""Get a CUDA stream for async weight offloading."""
global STREAMS, stream_counters, NUM_STREAMS
if NUM_STREAMS < 1:
return None
if not torch.cuda.is_available():
return None
device_idx = device.index if device.index is not None else 0
if device_idx not in STREAMS:
STREAMS[device_idx] = [torch.cuda.Stream(device=device) for _ in range(NUM_STREAMS)]
stream_counters[device_idx] = 0
stream_idx = stream_counters[device_idx] % NUM_STREAMS
stream_counters[device_idx] += 1
return STREAMS[device_idx][stream_idx]
def sync_stream(device: torch.device, stream):
"""Synchronize a CUDA stream."""
if stream is not None and torch.cuda.is_available():
stream.synchronize()
def sync_all_streams(device: torch.device = None):
"""Synchronize all streams for a device."""
global STREAMS
if device is None:
for dev_streams in STREAMS.values():
for stream in dev_streams:
stream.synchronize()
else:
device_idx = device.index if device.index is not None else 0
if device_idx in STREAMS:
for stream in STREAMS[device_idx]:
stream.synchronize()
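# Sketch of the intended round-robin pattern (`weights` is a hypothetical list
# of CPU parameter tensors): enqueue copies on alternating streams so transfers
# overlap, then synchronize before compute consumes the weights.
#
#     dev = get_torch_device()
#     for w in weights:
#         s = get_offload_stream(dev)
#         with torch.cuda.stream(s):
#             w.data = w.data.to(dev, non_blocking=True)
#     sync_all_streams(dev)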
# === Pinned Memory Management (from ComfyUI for faster CPU<->GPU transfers) ===
def init_pinned_memory():
"""Initialize pinned memory subsystem."""
global MAX_PINNED_MEMORY
try:
# Use up to 25% of system RAM for pinned memory (capped at 8GB)
total_ram = psutil.virtual_memory().total
MAX_PINNED_MEMORY = min(total_ram // 4, 8 * 1024 * 1024 * 1024)
    except Exception:
MAX_PINNED_MEMORY = 4 * 1024 * 1024 * 1024 # Default 4GB
def pin_memory(tensor: torch.Tensor, key: str = None) -> torch.Tensor:
"""Pin a CPU tensor for faster transfers to GPU."""
global PINNED_MEMORY, TOTAL_PINNED_MEMORY, MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY < 0:
init_pinned_memory()
if tensor.device.type != 'cpu' or tensor.is_pinned():
return tensor
tensor_size = tensor.nelement() * tensor.element_size()
if TOTAL_PINNED_MEMORY + tensor_size > MAX_PINNED_MEMORY:
return tensor # Not enough room
try:
pinned = tensor.pin_memory()
TOTAL_PINNED_MEMORY += tensor_size
if key is not None:
PINNED_MEMORY[key] = (pinned, tensor_size)
return pinned
    except Exception:
return tensor
def unpin_memory(key: str = None):
"""Unpin memory associated with a key."""
global PINNED_MEMORY, TOTAL_PINNED_MEMORY
if key is not None and key in PINNED_MEMORY:
_, tensor_size = PINNED_MEMORY.pop(key)
TOTAL_PINNED_MEMORY -= tensor_size
def clear_pinned_memory():
"""Clear all pinned memory."""
global PINNED_MEMORY, TOTAL_PINNED_MEMORY
PINNED_MEMORY.clear()
TOTAL_PINNED_MEMORY = 0
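# Sketch: pinning a CPU weight once makes later non-blocking host-to-device
# copies of it truly asynchronous (unpinned CPU tensors force a synchronous
# staging copy). The key is a hypothetical label for later release.
#
#     cpu_weight = pin_memory(cpu_weight, key="unet.input_blocks.0.weight")
#     gpu_weight = cpu_weight.to("cuda", non_blocking=True)
#     unpin_memory("unet.input_blocks.0.weight")  # drop the bookkeeping entry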
# === Optimized tensor transfer with async streams ===
def cast_to(tensor: torch.Tensor, device: torch.device, dtype: torch.dtype = None,
copy: bool = False, non_blocking: bool = True, stream=None):
"""Optimized tensor transfer with optional async streaming."""
target_dtype = dtype if dtype is not None else tensor.dtype
# Fast path: no change needed
if tensor.device == device and tensor.dtype == target_dtype and not copy:
return tensor
# Use provided stream or get one
if stream is None and NUM_STREAMS > 0 and torch.cuda.is_available():
stream = get_offload_stream(device)
if stream is not None:
with torch.cuda.stream(stream):
return tensor.to(device=device, dtype=target_dtype, copy=copy, non_blocking=non_blocking)
else:
return tensor.to(device=device, dtype=target_dtype, copy=copy, non_blocking=non_blocking)
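# Note the caller contract: when cast_to() runs on a side stream the copy is
# only enqueued, so the stream must be synchronized before the result is read.
# A minimal sketch (`w_cpu` is an assumed CPU tensor):
#
#     dev = get_torch_device()
#     s = get_offload_stream(dev)
#     w_gpu = cast_to(w_cpu, dev, dtype=torch.float16, stream=s)
#     sync_stream(dev, s)  # ensure the async copy finished before use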
def is_intel_xpu() -> bool:
return cpu_state == CPUState.GPU and xpu_available
def is_nvidia() -> bool:
return cpu_state == CPUState.GPU and bool(torch.version.cuda)
def is_rocm() -> bool:
return cpu_state == CPUState.GPU and bool(torch.version.hip)
def get_torch_device() -> torch.device:
if directml_enabled:
return directml_device
if cpu_state == CPUState.MPS:
return torch.device("mps")
if cpu_state == CPUState.CPU:
return torch.device("cpu")
if is_intel_xpu():
return torch.device("xpu", torch.xpu.current_device())
if torch.cuda.is_available():
return torch.device(torch.cuda.current_device())
return torch.device("cpu")
def get_total_memory(dev: torch.device = None, torch_total_too: bool = False) -> Union[int, Tuple[int, int]]:
dev = dev or get_torch_device()
if hasattr(dev, "type") and dev.type in ("cpu", "mps"):
mem = psutil.virtual_memory().total
return (mem, mem) if torch_total_too else mem
if directml_enabled:
mem = 1024 ** 3
return (mem, mem) if torch_total_too else mem
if is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_torch = stats["reserved_bytes.all.current"]
mem_total = torch.xpu.get_device_properties(dev).total_memory
else:
stats = torch.cuda.memory_stats(dev)
mem_torch = stats["reserved_bytes.all.current"]
_, mem_total = torch.cuda.mem_get_info(dev)
return (mem_total, mem_torch) if torch_total_too else mem_total
_FREE_MEM_CACHE = {}
_FREE_MEM_CACHE_TTL = 0.1 # 100ms
def get_free_memory(dev: torch.device = None, torch_free_too: bool = False) -> Union[int, Tuple[int, int]]:
global _FREE_MEM_CACHE
dev = dev or get_torch_device()
# Simple caching to avoid high frequency blocking calls in sampling loop
import time
now = time.time()
cache_key = (str(dev), torch_free_too)
if cache_key in _FREE_MEM_CACHE:
val, ts = _FREE_MEM_CACHE[cache_key]
if now - ts < _FREE_MEM_CACHE_TTL:
return val
if hasattr(dev, "type") and dev.type in ("cpu", "mps"):
mem = psutil.virtual_memory().available
res = (mem, mem) if torch_free_too else mem
_FREE_MEM_CACHE[cache_key] = (res, now)
return res
if directml_enabled:
mem = 1024 ** 3
res = (mem, mem) if torch_free_too else mem
_FREE_MEM_CACHE[cache_key] = (res, now)
return res
if is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
active = stats["active_bytes.all.current"]
reserved = stats["reserved_bytes.all.current"]
free_torch = reserved - active
free_total = torch.xpu.get_device_properties(dev).total_memory - reserved + free_torch
else:
        # torch.cuda.mem_get_info blocks on many Windows drivers; the TTL cache above limits how often we call it
stats = torch.cuda.memory_stats(dev)
active = stats["active_bytes.all.current"]
reserved = stats["reserved_bytes.all.current"]
free_cuda, _ = torch.cuda.mem_get_info(dev)
free_torch = reserved - active
free_total = free_cuda + free_torch
res = (free_total, free_torch) if torch_free_too else free_total
_FREE_MEM_CACHE[cache_key] = (res, now)
return res
def soft_empty_cache(force: bool = False) -> None:
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
elif is_intel_xpu():
torch.xpu.empty_cache()
elif torch.cuda.is_available() and (force or is_nvidia()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
# === torch.compile support (from ComfyUI for model optimization) ===
TORCH_COMPILE_ENABLED = False
COMPILED_MODELS = {}
def enable_torch_compile(enabled: bool = True):
"""Enable or disable torch.compile for model optimization."""
global TORCH_COMPILE_ENABLED
TORCH_COMPILE_ENABLED = enabled
if enabled:
logging.info("torch.compile enabled for model optimization")
def compile_model(model: torch.nn.Module, mode: str = "max-autotune-no-cudagraphs",
fullgraph: bool = False, dynamic: bool = True) -> torch.nn.Module:
"""Compile a model with torch.compile for faster inference.
Uses 'max-autotune-no-cudagraphs' by default. Avoid 'reduce-overhead'
as it enables CUDA graphs which cause assertion errors with dynamic
model state (LoRA patches, mixed dtypes, etc.).
Args:
model: The model to compile
mode: Compilation mode - "max-autotune-no-cudagraphs" (recommended),
"max-autotune", "default", or "reduce-overhead"
fullgraph: Whether to compile the full graph
dynamic: Whether to allow dynamic shapes
Returns:
Compiled model (or original if compilation fails)
"""
global COMPILED_MODELS
if not TORCH_COMPILE_ENABLED:
return model
# Check PyTorch version
if not hasattr(torch, 'compile'):
logging.warning("torch.compile not available (requires PyTorch 2.0+)")
return model
# Check if already compiled
model_id = id(model)
if model_id in COMPILED_MODELS:
return COMPILED_MODELS[model_id]
try:
# Use inductor backend for best performance
compiled = torch.compile(
model,
mode=mode,
fullgraph=fullgraph,
dynamic=dynamic,
backend="inductor"
)
COMPILED_MODELS[model_id] = compiled
logging.info(f"Model compiled successfully with mode={mode}")
return compiled
except Exception as e:
logging.warning(f"torch.compile failed: {e}")
return model
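# Sketch of the expected call sequence (`unet` is an assumed nn.Module):
# compilation is opt-in, so enable the flag first, then swap in the compiled
# module. The first forward pass pays the autotuning cost.
#
#     enable_torch_compile(True)
#     unet = compile_model(unet)   # defaults to max-autotune-no-cudagraphs
#     out = unet(x)                # first call triggers compilation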
def clear_compiled_models():
"""Clear the compiled models cache."""
global COMPILED_MODELS
COMPILED_MODELS.clear()
# Initialize PyTorch attention and VAE dtype
try:
if is_nvidia() or is_rocm():
        if int(torch.version.__version__.split('.')[0]) >= 2:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
if is_nvidia() and torch.cuda.get_device_properties(0).major >= 8:
VAE_DTYPE = torch.bfloat16
elif is_rocm():
VAE_DTYPE = torch.bfloat16
except Exception:
pass
if is_intel_xpu():
VAE_DTYPE = torch.bfloat16
if ENABLE_PYTORCH_ATTENTION and torch.cuda.is_available():
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
# Apply vram_state based on cpu_state
if cpu_state == CPUState.MPS:
    vram_state = VRAMState.SHARED
elif cpu_state != CPUState.GPU:
    vram_state = VRAMState.DISABLED
total_vram = get_total_memory() / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info(f"VRAM: {total_vram:.0f} MB, RAM: {total_ram:.0f} MB, Device: {get_torch_device()}, VAE dtype: {VAE_DTYPE}")
# Model management
current_loaded_models = []
def module_size(module: torch.nn.Module) -> int:
return sum(t.nelement() * t.element_size() for t in module.state_dict().values())
class LoadedModel:
def __init__(self, model):
self.model = model
self.device = model.load_device
self.weights_loaded = False
self.real_model = None
def __eq__(self, other):
return isinstance(other, LoadedModel) and self.model == other.model
def model_memory(self):
return self.model.model_size()
def model_offloaded_memory(self):
return self.model.model_size() - self.model.loaded_size()
def model_memory_required(self, device):
if hasattr(self.model, 'current_loaded_device') and device == self.model.current_loaded_device():
return self.model_offloaded_memory()
return self.model_memory()
def model_load(self, lowvram_model_memory: int = 0, force_patch_weights: bool = False):
self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype())
load_weights = not self.weights_loaded
try:
if hasattr(self.model, "patch_model_lowvram") and lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(
device_to=self.device, lowvram_model_memory=lowvram_model_memory,
force_patch_weights=force_patch_weights)
else:
# CRITICAL: parameter is patch_weights, not load_weights!
self.real_model = self.model.patch_model(device_to=self.device, patch_weights=load_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
self.weights_loaded = True
return self.real_model
def should_reload_model(self, force_patch_weights: bool = False) -> bool:
return force_patch_weights and self.model.lowvram_patch_counter > 0
def model_unload(self, unpatch_weights: bool = True):
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.real_model = None
def model_use_more_vram(self, extra_memory: int) -> int:
return self.model.partially_load(self.device, extra_memory)
def minimum_inference_memory() -> int:
return 1024 * 1024 * 1024
def extra_reserved_memory() -> int:
return EXTRA_RESERVED_VRAM
def unload_model_clones(model, unload_weights_only: bool = True, force_unload: bool = True):
to_unload = [i for i in range(len(current_loaded_models) - 1, -1, -1)
if model.is_clone(current_loaded_models[i].model)]
if not to_unload:
return True
if not force_unload and unload_weights_only:
return None
for i in to_unload:
current_loaded_models.pop(i).model_unload(unpatch_weights=True)
return True
def free_memory(memory_required: int, device: torch.device, keep_loaded: Optional[list] = None):
    keep_loaded = keep_loaded or []
    can_unload = [(sys.getrefcount(m.model), m.model_memory(), i)
                  for i, m in enumerate(current_loaded_models)
                  if m.device == device and m not in keep_loaded]
unloaded = []
for x in sorted(can_unload):
if not DISABLE_SMART_MEMORY and get_free_memory(device) > memory_required:
break
current_loaded_models[x[-1]].model_unload()
unloaded.append(x[-1])
for i in sorted(unloaded, reverse=True):
current_loaded_models.pop(i)
if unloaded:
soft_empty_cache()
def load_models_gpu(models: list, memory_required: int = 0, force_patch_weights: bool = False,
minimum_memory_required: int = None, force_full_load: bool = False):
global vram_state
# Handle mock objects in tests
if not isinstance(memory_required, int):
try:
memory_required = int(memory_required)
except Exception:
memory_required = 0
inference_memory = minimum_inference_memory()
if not isinstance(inference_memory, int):
try:
inference_memory = int(inference_memory)
except Exception:
inference_memory = 0
extra_mem = max(inference_memory, memory_required)
min_mem = minimum_memory_required or extra_mem
models_to_load, models_already_loaded = [], []
for x in set(models):
loaded_model = LoadedModel(x)
try:
idx = current_loaded_models.index(loaded_model)
loaded = current_loaded_models[idx]
if loaded.should_reload_model(force_patch_weights=force_patch_weights):
current_loaded_models.pop(idx).model_unload(unpatch_weights=True)
models_to_load.append(loaded_model)
else:
models_already_loaded.append(loaded)
except ValueError:
if hasattr(x, "model"):
logging.info(f"Loading {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
if not models_to_load:
for d in set(m.device for m in models_already_loaded):
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
return
# Calculate and free memory
mem_required = {}
for m in models_to_load:
if unload_model_clones(m.model, unload_weights_only=True, force_unload=False):
mem_required[m.device] = mem_required.get(m.device, 0) + m.model_memory_required(m.device)
for device, mem in mem_required.items():
if device != torch.device("cpu"):
free_memory(mem * 1.3 + extra_mem, device, models_already_loaded)
for m in models_to_load:
weights_unloaded = unload_model_clones(m.model, unload_weights_only=False, force_unload=False)
if weights_unloaded is not None:
m.weights_loaded = not weights_unloaded
# Load models
for loaded_model in models_to_load:
torch_dev = loaded_model.model.load_device
vram_set = VRAMState.DISABLED if is_device_cpu(torch_dev) else vram_state
        lowvram_mem = 0
        if vram_set in (VRAMState.LOW_VRAM, VRAMState.NORMAL_VRAM) and not force_full_load:
            model_size = loaded_model.model_memory_required(torch_dev)
            # Handle mock objects in tests
            if not isinstance(model_size, int):
                try:
                    model_size = int(model_size)
                except Exception:
                    model_size = 0
            current_free = get_free_memory(torch_dev)
            # Handle mock objects in tests (coerce before the arithmetic below)
            if not isinstance(current_free, int):
                try:
                    current_free = int(current_free)
                except Exception:
                    current_free = 10 * 1024 * 1024 * 1024  # 10GB fallback
            lowvram_mem = int(max(64 * 1024 * 1024, (current_free - 1024 * 1024 * 1024) / 1.3))
            if model_size <= current_free - inference_memory:
                lowvram_mem = 0
if vram_set == VRAMState.NO_VRAM:
lowvram_mem = 64 * 1024 * 1024
loaded_model.model_load(lowvram_mem, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
def load_model_gpu(model):
load_models_gpu([model])
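# Sketch (assuming `patcher` is a ModelPatcher-style object exposing
# load_device, model_size(), patch_model(), etc., as consumed by LoadedModel):
# free_memory() is invoked internally, so callers only state their budget.
#
#     load_models_gpu([patcher], memory_required=2 * 1024**3)  # ~2 GiB headroom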
def cleanup_models(keep_clone_weights_loaded: bool = False):
to_delete = [i for i in range(len(current_loaded_models) - 1, -1, -1)
if sys.getrefcount(current_loaded_models[i].model) <= 2 and
(not keep_clone_weights_loaded or sys.getrefcount(current_loaded_models[i].real_model) <= 3)]
for i in to_delete:
current_loaded_models.pop(i).model_unload()
def unload_all_models():
free_memory(int(1e30), get_torch_device())
# Device utilities
def is_device_type(device, device_type: str) -> bool:
    return hasattr(device, "type") and device.type == device_type
def is_device_cpu(device) -> bool:
return is_device_type(device, "cpu")
def is_device_mps(device) -> bool:
return is_device_type(device, "mps")
def is_device_cuda(device) -> bool:
return is_device_type(device, "cuda")
def cpu_mode() -> bool:
return cpu_state == CPUState.CPU
def mps_mode() -> bool:
return cpu_state == CPUState.MPS
# Dtype utilities
def dtype_size(dtype) -> int:
if dtype in (torch.float16, torch.bfloat16):
return 2
if dtype == torch.float32:
return 4
return getattr(dtype, 'itemsize', 4)
def supports_dtype(device, dtype) -> bool:
if dtype == torch.float32:
return True
return not is_device_cpu(device)
def supports_cast(device, dtype) -> bool:
if dtype in (torch.float32, torch.float16, torch.bfloat16):
return True
if directml_enabled or is_device_mps(device):
return False
return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
def is_fp8_supported(device=None) -> bool:
"""Check if FP8 (float8_e4m3fn) is supported on the device."""
if device is None:
device = get_torch_device()
if not is_device_cuda(device):
return False
# FP8 requires compute capability 8.9+ (Ada Lovelace) or 9.0+ (Hopper)
try:
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability(device)
if major >= 9:
return True
if major == 8 and minor >= 9:
return True
    except Exception:
pass
return False
def cast_to_fp8(tensor: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
"""Cast a tensor to FP8 (float8_e4m3fn)."""
if not hasattr(torch, "float8_e4m3fn"):
return tensor.to(torch.float16) # Fallback
# Scale if needed (scaling is often used for better precision in FP8)
if scale != 1.0:
tensor = tensor * scale
return tensor.to(torch.float8_e4m3fn)
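# Sketch: gate FP8 storage on hardware support; cast_to_fp8 already degrades to
# fp16 on PyTorch builds without float8_e4m3fn. `w` is an assumed weight tensor.
#
#     if is_fp8_supported():
#         w = cast_to_fp8(w)  # stored as float8_e4m3fn
#     # compute still happens in fp16/bf16 after an upcast at use time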
def cast_to_device(tensor, device, dtype, copy: bool = False):
non_blocking = not is_device_mps(device)
can_cast = tensor.dtype in (torch.float32, torch.float16) or \
(tensor.dtype == torch.bfloat16 and (is_device_cuda(device) or is_intel_xpu()))
if can_cast:
if copy and tensor.device == device:
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def pick_weight_dtype(dtype, fallback_dtype, device):
dtype = dtype or fallback_dtype
if dtype_size(dtype) > dtype_size(fallback_dtype):
dtype = fallback_dtype
if not supports_cast(device, dtype):
dtype = fallback_dtype
return dtype
# UNet/VAE/text encoder device helpers
def unet_offload_device() -> torch.device:
return get_torch_device() if vram_state == VRAMState.HIGH_VRAM else torch.device("cpu")
def unet_inital_load_device(parameters, dtype) -> torch.device:
if vram_state == VRAMState.HIGH_VRAM or DISABLE_SMART_MEMORY:
return get_torch_device() if vram_state == VRAMState.HIGH_VRAM else torch.device("cpu")
model_size = dtype_size(dtype) * parameters
if get_free_memory(get_torch_device()) > get_free_memory(torch.device("cpu")) and model_size < get_free_memory(get_torch_device()):
return get_torch_device()
return torch.device("cpu")
def unet_dtype(device=None, model_params: int = 0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if should_use_fp16(device=device, model_params=model_params, manual_cast=True) and torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device, model_params=model_params, manual_cast=True) and torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
return None
if should_use_fp16(inference_device, prioritize_performance=False) and weight_dtype == torch.float16:
return None
if should_use_bf16(inference_device) and weight_dtype == torch.bfloat16:
return None
if should_use_fp16(inference_device, prioritize_performance=False) and torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(inference_device) and torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32
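# Sketch of how the two helpers combine at model-load time: unet_dtype() picks
# the storage dtype, and unet_manual_cast() decides whether weights must be
# cast on the fly because the device cannot compute in that dtype natively.
# `n_params` (the parameter count) is assumed known from the checkpoint.
#
#     dev = get_torch_device()
#     storage = unet_dtype(device=dev, model_params=n_params)
#     manual = unet_manual_cast(storage, dev)  # None => compute natively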
def text_encoder_offload_device() -> torch.device:
return torch.device("cpu")
def text_encoder_device() -> torch.device:
if vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) and should_use_fp16(prioritize_performance=False):
return get_torch_device()
return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size: int = 0):
if load_device == offload_device or model_size <= 1024 ** 3 or is_device_mps(load_device):
return offload_device
if get_free_memory(load_device) > get_free_memory(offload_device) * 0.5 and model_size * 1.2 < get_free_memory(load_device):
return load_device
return offload_device
def text_encoder_dtype(device=None):
if is_device_cpu(device):
return torch.float16
return torch.bfloat16 if should_use_bf16(device) else torch.float16
def intermediate_device() -> torch.device:
return torch.device("cpu")
def vae_device() -> torch.device:
return get_torch_device()
def vae_offload_device() -> torch.device:
return torch.device("cpu")
def vae_dtype():
return VAE_DTYPE
def get_autocast_device(dev) -> str:
return getattr(dev, "type", "cuda")
# Feature detection
def sageattention_enabled() -> bool:
if cpu_state != CPUState.GPU or is_intel_xpu() or directml_enabled or is_rocm():
return False
return SAGEATTENTION_IS_AVAILABLE
def sageattention_enabled_vae() -> bool:
return sageattention_enabled() and SAGEATTENTION_ENABLED_VAE
def spargeattn_enabled() -> bool:
if cpu_state != CPUState.GPU or is_intel_xpu() or directml_enabled or is_rocm():
return False
if torch.cuda.is_available():
try:
if torch.cuda.get_device_capability()[0] >= 12:
return False
        except Exception:
pass
return SPARGEATTN_IS_AVAILABLE
def spargeattn_enabled_vae() -> bool:
return spargeattn_enabled() and SPARGEATTN_ENABLED_VAE
def xformers_enabled() -> bool:
if cpu_state != CPUState.GPU or is_intel_xpu() or directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
def xformers_enabled_vae() -> bool:
return xformers_enabled() and XFORMERS_ENABLED_VAE
def pytorch_attention_enabled() -> bool:
return ENABLE_PYTORCH_ATTENTION
def pytorch_attention_flash_attention() -> bool:
return ENABLE_PYTORCH_ATTENTION and (is_nvidia() or is_rocm())
def device_supports_non_blocking(device) -> bool:
return not is_device_mps(device)
# FP16/BF16 support detection
def should_use_fp16(device=None, model_params: int = 0, prioritize_performance: bool = True, manual_cast: bool = False) -> bool:
if FORCE_FP16:
return True
if FORCE_FP32 or directml_enabled or cpu_mode():
return False
if device and is_device_cpu(device):
return False
if mps_mode() or (device and is_device_mps(device)):
return True
if is_intel_xpu() or is_rocm():
return True
if not torch.cuda.is_available():
return False
    if device is None or not is_device_cuda(device):
        device = torch.device("cuda")
    props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
if props.major < 6:
return False
# Check 10-series cards
fp16_works = any(x in props.name.lower() for x in ["1080", "1070", "titan x", "p3000", "p4000", "p5000", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"])
if fp16_works or manual_cast:
# Handle mock objects in tests
try:
free_mem = int(get_free_memory())
min_inf_mem = int(minimum_inference_memory())
except Exception:
free_mem = 10 * 1024 * 1024 * 1024
min_inf_mem = 0
if not prioritize_performance or model_params * 4 > free_mem * 0.9 - min_inf_mem:
return True
if props.major < 7:
return False
# Exclude 16-series
return not any(x in props.name for x in ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"])
def should_use_bf16(device=None, model_params: int = 0, prioritize_performance: bool = True, manual_cast: bool = False) -> bool:
if FORCE_FP32 or directml_enabled or cpu_mode() or mps_mode():
return False
if device and (is_device_cpu(device) or is_device_mps(device)):
return False
if is_intel_xpu():
return True
if is_rocm():
try:
return torch.cuda.is_bf16_supported()
        except Exception:
return False
device = device or torch.device("cuda")
if torch.cuda.get_device_properties(device).major >= 8:
return True
try:
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
# Handle mock objects in tests
try:
free_mem = int(get_free_memory())
min_inf_mem = int(minimum_inference_memory())
except Exception:
free_mem = 10 * 1024 * 1024 * 1024
min_inf_mem = 0
if not prioritize_performance or model_params * 4 > free_mem * 0.9 - min_inf_mem:
return True
    except Exception:
pass
return False
def resolve_lowvram_weight(weight, model, key):
return weight