# infinitetalk/utils/gpu_manager.py
"""
GPU Memory Manager for InfiniteTalk
Handles memory monitoring, cleanup, and optimization
"""
import torch
import logging
from typing import Optional
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GPUManager:
    """Manages GPU memory usage and optimization"""

    def __init__(self, max_memory_gb=65):
        """
        Initialize GPU Manager

        Args:
            max_memory_gb: Maximum memory threshold in GB (default 65GB for 70GB H200)
        """
        self.max_memory_bytes = max_memory_gb * 1024 ** 3
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def get_memory_usage(self):
        """
        Get current GPU memory usage

        Returns:
            dict with allocated, reserved, free, and total memory in GB
        """
        if not torch.cuda.is_available():
            return {"allocated": 0, "reserved": 0, "free": 0, "total": 0}

        allocated = torch.cuda.memory_allocated() / 1024 ** 3
        reserved = torch.cuda.memory_reserved() / 1024 ** 3
        total = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
        free = total - allocated

        return {
            "allocated": round(allocated, 2),
            "reserved": round(reserved, 2),
            "free": round(free, 2),
            "total": round(total, 2),
        }
    def print_memory_usage(self, prefix=""):
        """Print current memory usage"""
        usage = self.get_memory_usage()
        logger.info(
            f"{prefix}GPU Memory - "
            f"Allocated: {usage['allocated']}GB, "
            f"Reserved: {usage['reserved']}GB, "
            f"Free: {usage['free']}GB"
        )
    def check_memory_threshold(self):
        """
        Check if memory usage exceeds threshold

        Returns:
            bool: True if within safe limits, False if exceeded
        """
        if not torch.cuda.is_available():
            return True

        allocated = torch.cuda.memory_allocated()
        if allocated > self.max_memory_bytes:
            logger.warning(
                f"Memory threshold exceeded! "
                f"Allocated: {allocated / 1024**3:.2f}GB, "
                f"Threshold: {self.max_memory_bytes / 1024**3:.2f}GB"
            )
            return False
        return True
    def cleanup(self):
        """Perform garbage collection and CUDA cache cleanup"""
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        logger.info("GPU memory cleaned up")
        self.print_memory_usage("After cleanup - ")
    def optimize_model_for_inference(self, model):
        """
        Apply optimizations to model for inference

        Args:
            model: PyTorch model to optimize

        Returns:
            Optimized model
        """
        model.eval()

        # Enable gradient checkpointing if available
        if hasattr(model, "enable_gradient_checkpointing"):
            model.enable_gradient_checkpointing()

        # Use FP16 for inference to save memory
        if torch.cuda.is_available() and hasattr(model, "half"):
            logger.info("Converting model to FP16")
            model = model.half()

        return model
    def enable_memory_efficient_attention(self):
        """Enable memory-efficient attention mechanisms"""
        try:
            import xformers  # noqa: F401

            logger.info("xformers available - memory efficient attention enabled")
            return True
        except ImportError:
            logger.warning("xformers not available - using standard attention")
            return False
    def estimate_inference_memory(self, resolution="480p", duration_seconds=10):
        """
        Estimate memory requirements for inference

        Args:
            resolution: Video resolution (480p or 720p)
            duration_seconds: Video duration in seconds

        Returns:
            Estimated memory in GB
        """
        base_memory = 20  # Base model memory
        if resolution == "720p":
            per_second_memory = 1.5
        else:  # 480p
            per_second_memory = 0.8

        estimated = base_memory + (duration_seconds * per_second_memory)
        logger.info(
            f"Estimated memory for {resolution} video ({duration_seconds}s): "
            f"{estimated:.2f}GB"
        )
        return estimated
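    # Worked example (figures come from the rough estimates above, not from
    # measurements): a 30 s 720p clip comes out at 20 + 30 * 1.5 = 65 GB,
    # which is above the 50 GB cutoff used by should_use_chunking() below.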
    def should_use_chunking(self, video_duration, resolution="480p"):
        """
        Determine if chunked processing should be used

        Args:
            video_duration: Duration in seconds
            resolution: Video resolution

        Returns:
            bool: True if chunking recommended
        """
        estimated_memory = self.estimate_inference_memory(resolution, video_duration)
        # Use chunking if estimated memory exceeds 50GB
        return estimated_memory > 50
    def get_optimal_chunk_size(self, resolution="480p"):
        """
        Get optimal chunk size for video processing

        Args:
            resolution: Video resolution

        Returns:
            Optimal chunk size in seconds
        """
        if resolution == "720p":
            return 10  # 10 second chunks for 720p
        else:
            return 15  # 15 second chunks for 480p
    @staticmethod
    def calculate_duration_for_zerogpu(video_duration, resolution="480p"):
        """
        Calculate ZeroGPU duration parameter

        Args:
            video_duration: Duration of video in seconds
            resolution: Video resolution

        Returns:
            Recommended duration for @spaces.GPU decorator
        """
        base_time = 60  # Base time for model loading

        # Processing time per second of video
        if resolution == "720p":
            processing_rate = 3.5
        else:  # 480p
            processing_rate = 2.5

        # Add safety margin of 1.2x
        estimated_time = base_time + (video_duration * processing_rate)
        duration = int(estimated_time * 1.2)

        # Cap at 300 seconds for free tier (300s ZeroGPU = 10 min real time)
        duration = min(duration, 300)

        logger.info(
            f"Calculated ZeroGPU duration: {duration}s for "
            f"{video_duration}s {resolution} video"
        )
        return duration
# Global instance
gpu_manager = GPUManager()
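

if __name__ == "__main__":
    # Minimal usage sketch (illustrative values only). It exercises only the
    # read-only helpers above, so it is safe to run with or without a GPU.
    manager = GPUManager(max_memory_gb=65)
    manager.print_memory_usage("Startup - ")

    # A hypothetical 45 s 480p clip: 20 GB base + 45 * 0.8 GB/s = 56 GB,
    # above the 50 GB cutoff, so chunked processing is recommended.
    if manager.should_use_chunking(45, resolution="480p"):
        chunk = manager.get_optimal_chunk_size(resolution="480p")
        logger.info(f"Chunked processing recommended, chunk size: {chunk}s")

    # A hypothetical 60 s 480p video: 60 + 60 * 2.5 = 210 s, times the 1.2
    # safety margin = 252 s, which stays under the 300 s cap.
    GPUManager.calculate_duration_for_zerogpu(60, resolution="480p")
    manager.cleanup()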