#!/usr/bin/env python3
"""
Memory Manager for BackgroundFX Pro

- Safe on CPU/CUDA/MPS (mostly CUDA/T4 on Spaces)
- Accepts `device` as str or torch.device
- Optional per-process VRAM cap (env or method)
- Detailed usage reporting (CPU/RAM + VRAM + torch allocator)
- Light and aggressive cleanup paths
- Background monitor (optional)

Env switches:
  BFX_DISABLE_LIMIT=1     -> do not set VRAM fraction automatically
  BFX_CUDA_FRACTION=0.80  -> fraction to cap per-process VRAM (0.10..0.95)
"""

from __future__ import annotations

import gc
import os
import time
import logging
import threading
from typing import Dict, Any, Optional, Callable

# Optional deps
try:
    import psutil
except Exception:  # pragma: no cover
    psutil = None

try:
    import torch
except Exception:  # pragma: no cover
    torch = None

logger = logging.getLogger(__name__)


# ---- local exception to avoid shadowing built-in MemoryError ----
class MemoryManagerError(Exception):
    pass


def _bytes_to_gb(x: int | float) -> float:
    try:
        return float(x) / (1024 ** 3)
    except Exception:
        return 0.0


def _normalize_device(dev) -> "torch.device":
    if torch is None:
        # fake CPU device
        class _Fake:
            type = "cpu"
            index = None
        return _Fake()  # type: ignore[return-value]
    if isinstance(dev, str):
        return torch.device(dev)
    if hasattr(dev, "type"):
        return dev
    # default CPU
    return torch.device("cpu")


def _cuda_index(device) -> Optional[int]:
    if getattr(device, "type", "cpu") != "cuda":
        return None
    idx = getattr(device, "index", None)
    if idx is None:
        # normalize bare "cuda" to 0
        return 0
    return int(idx)


class MemoryManager:
    """
    Comprehensive memory management with VRAM cap + cleanup utilities.
    """

    def __init__(self, device, memory_limit_gb: Optional[float] = None):
        self.device = _normalize_device(device)
        self.device_type = getattr(self.device, "type", "cpu")
        self.cuda_idx = _cuda_index(self.device)

        self.gpu_available = bool(
            torch and self.device_type == "cuda" and torch.cuda.is_available()
        )
        self.mps_available = bool(
            torch
            and self.device_type == "mps"
            and getattr(torch.backends, "mps", None)
            and torch.backends.mps.is_available()
        )

        self.memory_limit_gb = memory_limit_gb
        self.cleanup_callbacks: list[Callable] = []
        self.monitoring_active = False
        self.monitoring_thread: Optional[threading.Thread] = None

        self.stats = {
            "cleanup_count": 0,
            "peak_memory_usage": 0.0,
            "total_allocated": 0.0,
            "total_freed": 0.0,
        }
        self.applied_fraction: Optional[float] = None

        self._initialize_memory_limits()
        self._maybe_apply_vram_fraction()

        logger.info(f"MemoryManager initialized (device={self.device}, cuda={self.gpu_available})")

    # -------------------------------
    # init helpers
    # -------------------------------
    def _initialize_memory_limits(self):
        try:
            if self.gpu_available:
                props = torch.cuda.get_device_properties(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(props.total_memory)
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.80)  # default 80%
                logger.info(
                    f"CUDA memory limit baseline ~{self.memory_limit_gb:.1f}GB "
                    f"(device total {total_gb:.1f}GB)"
                )
            elif self.mps_available:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.50)
                logger.info(f"MPS memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
            else:
                vm = psutil.virtual_memory() if psutil else None
                total_gb = _bytes_to_gb(vm.total) if vm else 0.0
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.60)
                logger.info(f"CPU memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
        except Exception as e:
            logger.warning(f"Memory limit init failed: {e}")
            if self.memory_limit_gb is None:
                self.memory_limit_gb = 4.0  # conservative fallback

    def _maybe_apply_vram_fraction(self):
        """Apply BFX_CUDA_FRACTION (default 0.80) unless BFX_DISABLE_LIMIT is set."""
        if not self.gpu_available or torch is None:
            return
        if os.environ.get("BFX_DISABLE_LIMIT", ""):
            return
        frac_env = os.environ.get("BFX_CUDA_FRACTION", "").strip()
        try:
            fraction = float(frac_env) if frac_env else 0.80
        except Exception:
            fraction = 0.80
        applied = self.limit_cuda_memory(fraction=fraction)
        if applied:
            logger.info(f"Per-process CUDA memory fraction set to {applied:.2f} on device {self.cuda_idx or 0}")

    # -------------------------------
    # public API
    # -------------------------------
    def get_memory_usage(self) -> Dict[str, Any]:
        """Return a snapshot of system RAM, process, VRAM, and torch allocator usage."""
        usage: Dict[str, Any] = {
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "timestamp": time.time(),
        }

        # CPU / system
        if psutil:
            try:
                vm = psutil.virtual_memory()
                usage.update(
                    dict(
                        system_total_gb=round(_bytes_to_gb(vm.total), 3),
                        system_available_gb=round(_bytes_to_gb(vm.available), 3),
                        system_used_gb=round(_bytes_to_gb(vm.used), 3),
                        system_percent=float(vm.percent),
                    )
                )
                swap = psutil.swap_memory()
                usage.update(
                    dict(
                        swap_total_gb=round(_bytes_to_gb(swap.total), 3),
                        swap_used_gb=round(_bytes_to_gb(swap.used), 3),
                        swap_percent=float(swap.percent),
                    )
                )
                proc = psutil.Process()
                mi = proc.memory_info()
                usage.update(
                    dict(
                        process_rss_gb=round(_bytes_to_gb(mi.rss), 3),
                        process_vms_gb=round(_bytes_to_gb(mi.vms), 3),
                    )
                )
            except Exception as e:
                logger.debug(f"psutil stats error: {e}")

        # GPU
        if self.gpu_available and torch is not None:
            try:
                # mem_get_info returns (free, total) in bytes
                free_b, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                used_b = total_b - free_b
                usage.update(
                    dict(
                        vram_total_gb=round(_bytes_to_gb(total_b), 3),
                        vram_used_gb=round(_bytes_to_gb(used_b), 3),
                        vram_free_gb=round(_bytes_to_gb(free_b), 3),
                        vram_used_percent=float(used_b / total_b * 100.0) if total_b else 0.0,
                    )
                )
            except Exception as e:
                logger.debug(f"mem_get_info failed: {e}")

            # torch allocator stats
            try:
                idx = self.cuda_idx or 0
                allocated = torch.cuda.memory_allocated(idx)
                reserved = torch.cuda.memory_reserved(idx)
                usage["torch_allocated_gb"] = round(_bytes_to_gb(allocated), 3)
                usage["torch_reserved_gb"] = round(_bytes_to_gb(reserved), 3)
                # inactive split (2.x)
                try:
                    inactive = torch.cuda.memory_stats(idx).get("inactive_split_bytes.all.current", 0)
                    usage["torch_inactive_split_gb"] = round(_bytes_to_gb(inactive), 3)
                except Exception:
                    pass
            except Exception as e:
                logger.debug(f"allocator stats failed: {e}")

        usage["applied_fraction"] = self.applied_fraction

        # Update peak tracker
        current = usage.get("vram_used_gb", usage.get("system_used_gb", 0.0))
        try:
            if float(current) > float(self.stats["peak_memory_usage"]):
                self.stats["peak_memory_usage"] = float(current)
        except Exception:
            pass

        return usage

    def limit_cuda_memory(self, fraction: Optional[float] = None, max_gb: Optional[float] = None) -> Optional[float]:
        """Cap per-process VRAM to `fraction` (or derive it from `max_gb`); return the applied fraction or None."""
        if not self.gpu_available or torch is None:
            return None

        # derive fraction from max_gb if provided
        if max_gb is not None:
            try:
                _, total_b = torch.cuda.mem_get_info(self.cuda_idx or 0)
                total_gb = _bytes_to_gb(total_b)
                if total_gb <= 0:
                    return None
                fraction = min(max(0.10, max_gb / total_gb), 0.95)
            except Exception as e:
                logger.debug(f"fraction from max_gb failed: {e}")
                return None

        if fraction is None:
            fraction = 0.80
        fraction = float(max(0.10, min(0.95, fraction)))

        try:
            torch.cuda.set_per_process_memory_fraction(fraction, device=self.cuda_idx or 0)
            self.applied_fraction = fraction
            return fraction
        except Exception as e:
            logger.debug(f"set_per_process_memory_fraction failed: {e}")
            return None

    def cleanup(self) -> None:
        """Light cleanup used frequently between steps."""
        try:
            gc.collect()
        except Exception:
            pass
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        self.stats["cleanup_count"] += 1

    def cleanup_basic(self) -> None:
        """Alias kept for compatibility."""
        self.cleanup()

    def cleanup_aggressive(self) -> None:
        """Aggressive cleanup for OOM recovery or big scene switches."""
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.synchronize(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                torch.cuda.reset_peak_memory_stats(self.cuda_idx or 0)
            except Exception:
                pass
            try:
                if hasattr(torch.cuda, "ipc_collect"):
                    torch.cuda.ipc_collect()
            except Exception:
                pass
        try:
            gc.collect()
            gc.collect()
        except Exception:
            pass
        self.stats["cleanup_count"] += 1

    def register_cleanup_callback(self, callback: Callable):
        """Register an external cleanup callback."""
        self.cleanup_callbacks.append(callback)

    def start_monitoring(self, interval_seconds: float = 30.0, pressure_callback: Optional[Callable] = None):
        """Start a daemon thread that checks memory pressure and runs aggressive cleanup when critical."""
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return

        self.monitoring_active = True

        def loop():
            while self.monitoring_active:
                try:
                    pressure = self.check_memory_pressure()
                    if pressure["under_pressure"]:
                        logger.warning(
                            f"Memory pressure: {pressure['pressure_level']} "
                            f"({pressure['usage_percent']:.1f}%)"
                        )
                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")
                        if pressure["pressure_level"] == "critical":
                            self.cleanup_aggressive()
                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")
                time.sleep(interval_seconds)

        self.monitoring_thread = threading.Thread(target=loop, daemon=True)
        self.monitoring_thread.start()
        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")

    def stop_monitoring(self):
        """Stop the background monitor thread."""
        if self.monitoring_active:
            self.monitoring_active = False
            if self.monitoring_thread and self.monitoring_thread.is_alive():
                self.monitoring_thread.join(timeout=5.0)
            logger.info("Memory monitoring stopped")

    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
        """Classify pressure (normal/warning/critical) from VRAM or system RAM usage and suggest actions."""
        usage = self.get_memory_usage()
        info = {
            "under_pressure": False,
            "pressure_level": "normal",
            "usage_percent": 0.0,
            "recommendations": [],
        }

        if self.gpu_available:
            percent = usage.get("vram_used_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Run aggressive memory cleanup",
                        "Reduce frame cache / chunk size",
                        "Lower resolution or disable previews",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor memory usage",
                        "Reduce keyframe interval",
                    ]
        else:
            percent = usage.get("system_percent", 0.0)
            info["usage_percent"] = percent
            if percent >= threshold_percent:
                info["under_pressure"] = True
                if percent >= 95:
                    info["pressure_level"] = "critical"
                    info["recommendations"] += [
                        "Close other processes",
                        "Reduce resolution",
                        "Split video into chunks",
                    ]
                else:
                    info["pressure_level"] = "warning"
                    info["recommendations"] += [
                        "Run cleanup",
                        "Monitor usage",
                        "Reduce processing footprint",
                    ]

        return info

    def estimate_memory_requirement(self,
                                    video_width: int,
                                    video_height: int,
                                    frames_in_memory: int = 5) -> Dict[str, float]:
        """Rough estimate: 3 bytes/pixel frames with a 3x intermediate overhead, plus fixed model/system budgets."""
        bytes_per_frame = video_width * video_height * 3
        overhead_multiplier = 3.0  # masks/intermediates
        frames_gb = _bytes_to_gb(bytes_per_frame * frames_in_memory * overhead_multiplier)
        estimate = {
            "frames_memory_gb": round(frames_gb, 3),
            "model_memory_gb": 4.0,
            "system_overhead_gb": 2.0,
        }
        estimate["total_estimated_gb"] = round(
            estimate["frames_memory_gb"] + estimate["model_memory_gb"] + estimate["system_overhead_gb"], 3
        )
        return estimate

    def can_process_video(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, Any]:
        """Compare the estimated requirement against free VRAM (or available system RAM on CPU)."""
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        usage = self.get_memory_usage()

        if self.gpu_available:
            available = usage.get("vram_free_gb", 0.0)
        else:
            available = usage.get("system_available_gb", 0.0)

        can = estimate["total_estimated_gb"] <= available
        return {
            "can_process": can,
            "estimated_memory_gb": estimate["total_estimated_gb"],
            "available_memory_gb": available,
            "memory_margin_gb": round(available - estimate["total_estimated_gb"], 3),
            "recommendations": [] if can else [
                "Reduce resolution or duration",
                "Process in smaller chunks",
                "Run aggressive cleanup before start",
            ],
        }

    def get_stats(self) -> Dict[str, Any]:
        """Return manager-level counters and configuration."""
        return {
            "cleanup_count": self.stats["cleanup_count"],
            "peak_memory_usage_gb": self.stats["peak_memory_usage"],
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "applied_fraction": self.applied_fraction,
            "monitoring_active": self.monitoring_active,
            "callbacks_registered": len(self.cleanup_callbacks),
        }

    def __del__(self):
        try:
            self.stop_monitoring()
            self.cleanup_aggressive()
        except Exception:
            pass
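

# --- Minimal usage sketch (illustrative only) ---
# A hedged example of how this manager is typically driven: instantiate it for
# the active device, check headroom before processing, and run cleanups between
# steps. The 1920x1080 resolution and 30 s monitor interval are arbitrary
# example values, not project defaults.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    device = "cuda" if (torch is not None and torch.cuda.is_available()) else "cpu"
    manager = MemoryManager(device)

    # Snapshot current CPU/VRAM usage and check whether a 1080p job should fit.
    print(manager.get_memory_usage())
    print(manager.can_process_video(video_width=1920, video_height=1080, frames_in_memory=5))

    # Optionally watch for pressure in the background; critical pressure
    # triggers cleanup_aggressive() inside the monitor loop.
    manager.start_monitoring(interval_seconds=30.0)
    try:
        manager.cleanup()             # light cleanup between steps
        manager.cleanup_aggressive()  # heavier path, e.g. after an OOM
        print(manager.get_stats())
    finally:
        manager.stop_monitoring()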