# talkingAvater_bgk / core / optimization / cold_start_optimization.py
# Last commit 8ab20fc by oKen38461: removed the unnecessary queue setting from the
# ColdStartOptimizer Gradio launch parameters and raised the maximum thread count to 40.
"""
Cold Start Optimization for DittoTalkingHead
Reduces model loading time and I/O overhead
"""
import os
import pickle
import shutil
import time
from pathlib import Path
from typing import Any, Callable, Dict, Optional

import torch
class ColdStartOptimizer:
    """Reduce cold-start time by caching model files and loaded models.

    Two cache layers are kept:

    * ``model_cache`` -- an in-memory dict mapping a cache key to whatever the
      caller's ``load_func`` returned.
    * a directory on disk, chosen by :meth:`get_persistent_path`, that holds
      copies of model files.

    ``load_times`` maps each cache key to the seconds spent on its most recent
    uncached load.
    """

    # File name endings that setup_persistent_model_cache treats as models.
    MODEL_SUFFIXES = (".pth", ".pkl", ".onnx", ".trt")

    def __init__(self, persistent_dir: str = "/tmp/persistent_model_cache"):
        """Create the optimizer and make sure the fallback directory exists.

        Args:
            persistent_dir: Directory used for on-disk caching when none of
                the paths in ``hf_persistent_paths`` is usable.
        """
        self.persistent_dir = Path(persistent_dir)
        self.persistent_dir.mkdir(parents=True, exist_ok=True)

        # Hugging Face Spaces persistent paths, tried in order.
        self.hf_persistent_paths = [
            "/data",  # Primary persistent storage
            "/tmp/persistent",  # Fallback
        ]

        # In-memory model cache and per-key load timings.
        self.model_cache = {}
        self.load_times = {}

    @staticmethod
    def _copy_atomic(source: Path, target: Path) -> None:
        """Copy ``source`` to ``target`` so ``target`` is never half-written.

        The data is first written to a staging file next to ``target`` and
        then renamed over it, so an interrupted copy cannot leave a truncated
        file under the final name. The parent directory is created if needed.

        Raises:
            OSError: If the directory cannot be created or the copy fails.
        """
        source = Path(source)
        target = Path(target)
        target.parent.mkdir(parents=True, exist_ok=True)
        staging = target.with_name(f".{target.name}.{os.getpid()}.tmp")
        try:
            shutil.copy2(source, staging)
            os.replace(staging, target)
        finally:
            # After a successful rename the staging name is already gone;
            # after a failure this removes the partial file.
            try:
                staging.unlink()
            except OSError:
                pass

    def get_persistent_path(self) -> Path:
        """Return the directory to use for the on-disk cache.

        The first entry of ``hf_persistent_paths`` that exists and is writable
        wins, with ``model_cache`` appended to it. The returned directory is
        not created here. If no entry qualifies, ``persistent_dir`` is used.

        Returns:
            Path to the on-disk cache directory.
        """
        for candidate in self.hf_persistent_paths:
            if os.path.exists(candidate) and os.access(candidate, os.W_OK):
                return Path(candidate) / "model_cache"
        return self.persistent_dir

    def setup_persistent_model_cache(self, source_dir: str) -> bool:
        """Mirror model files from ``source_dir`` into the on-disk cache.

        Every regular file below ``source_dir`` whose name ends with one of
        ``MODEL_SUFFIXES`` is copied to the same relative location in the
        cache, unless a file already exists there.

        Args:
            source_dir: Source directory containing models.

        Returns:
            False if ``source_dir`` does not exist, True otherwise.

        Raises:
            OSError: If a directory cannot be created or a copy fails.
        """
        persistent_path = self.get_persistent_path()
        persistent_path.mkdir(parents=True, exist_ok=True)

        source_path = Path(source_dir)
        if not source_path.exists():
            print(f"Source directory {source_dir} not found")
            return False

        # One traversal instead of one glob per suffix; directories that
        # happen to be named like a model file are skipped.
        model_files = sorted(
            entry
            for entry in source_path.rglob("*")
            if entry.name.endswith(self.MODEL_SUFFIXES) and entry.is_file()
        )

        copied = 0
        for model_file in model_files:
            relative_path = model_file.relative_to(source_path)
            target_path = persistent_path / relative_path
            if target_path.exists():
                continue
            self._copy_atomic(model_file, target_path)
            copied += 1
            print(f"Copied {relative_path} to persistent storage")

        print(f"Persistent cache setup complete. Copied {copied} new files.")
        return True

    def load_model_cached(
        self,
        model_path: str,
        load_func: Callable[[str], Any],
        cache_key: Optional[str] = None
    ) -> Any:
        """Load a model, consulting the memory cache and then the disk cache.

        Lookup order:

        1. ``model_cache[cache_key]`` -- returned as is.
        2. ``<persistent path>/<file name of model_path>`` -- passed to
           ``load_func`` if it exists.
        3. ``model_path`` itself -- passed to ``load_func``; afterwards the
           file is copied into the persistent directory on a best-effort
           basis (a failed copy only prints a warning).

        Note that the disk copy is keyed by the base file name only, so two
        models with the same file name in different directories share a slot.

        Args:
            model_path: Path to model file.
            load_func: Called with a path string; its return value is cached.
            cache_key: Optional cache key (defaults to ``model_path``).

        Returns:
            The object returned by ``load_func`` (or the cached one).
        """
        cache_key = cache_key or model_path

        # In-memory cache first.
        if cache_key in self.model_cache:
            print(f"✅ Loaded {cache_key} from memory cache")
            return self.model_cache[cache_key]

        model_name = Path(model_path).name
        persistent_model_path = self.get_persistent_path() / model_name

        start_time = time.time()
        if persistent_model_path.exists():
            print(f"Loading {model_name} from persistent storage...")
            model = load_func(str(persistent_model_path))
        else:
            print(f"Loading {model_name} from original location...")
            model = load_func(model_path)
            # Best effort: the cache directory may not exist yet, so the copy
            # helper creates it; any I/O failure must not fail the load.
            try:
                self._copy_atomic(Path(model_path), persistent_model_path)
                print(f"Cached {model_name} to persistent storage")
            except OSError as e:
                print(f"Warning: Could not cache to persistent storage: {e}")

        load_time = time.time() - start_time
        self.load_times[cache_key] = load_time
        self.model_cache[cache_key] = model

        print(f"✅ Loaded {cache_key} in {load_time:.2f}s")
        return model

    def preload_models(self, model_configs: Dict[str, Dict[str, Any]]):
        """Preload several models one after another, highest priority first.

        Loading is sequential. A failure while loading one model is printed
        and does not stop the remaining models from being loaded.

        Args:
            model_configs: Dictionary of model configurations
                {
                    'model_name': {
                        'path': 'path/to/model',
                        'load_func': callable,
                        'priority': int (0-10)
                    }
                }
                ``priority`` defaults to 5 when omitted; the model name is
                used as the cache key.
        """
        sorted_models = sorted(
            model_configs.items(),
            key=lambda item: item[1].get('priority', 5),
            reverse=True
        )

        for model_name, config in sorted_models:
            try:
                self.load_model_cached(
                    config['path'],
                    config['load_func'],
                    cache_key=model_name
                )
            except Exception as e:
                # Deliberately broad: one bad model must not abort the rest.
                print(f"Error preloading {model_name}: {e}")

    def optimize_gradio_settings(self) -> Dict[str, Any]:
        """Return the Gradio launch parameters used by this project.

        Returns:
            Gradio launch parameters.
        """
        return {
            'max_threads': 40,  # Increase parallel processing
            'show_error': True,
            'server_name': '0.0.0.0',
            'server_port': 7860,
            'share': False,  # Disable share link for faster startup
        }

    def get_optimization_stats(self) -> Dict[str, Any]:
        """Summarize the state of both cache layers.

        Returns:
            Dict with the on-disk cache path, the number of models held in
            memory, the number and total size (MiB) of files below the
            on-disk cache path, the recorded load times, and their mean
            (0 when nothing has been loaded).
        """
        persistent_path = self.get_persistent_path()

        cached_files = 0
        total_size = 0
        if persistent_path.exists():
            for entry in persistent_path.rglob("*"):
                if entry.is_file():
                    cached_files += 1
                    total_size += entry.stat().st_size

        if self.load_times:
            average_load_time = sum(self.load_times.values()) / len(self.load_times)
        else:
            average_load_time = 0

        return {
            'persistent_path': str(persistent_path),
            'cached_models': len(self.model_cache),
            'cached_files': cached_files,
            'total_cache_size_mb': total_size / (1024 * 1024),
            'load_times': self.load_times,
            'average_load_time': average_load_time
        }

    def clear_memory_cache(self):
        """Drop all in-memory models and, if CUDA is available, empty its cache."""
        self.model_cache.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Memory cache cleared")

    def setup_streaming_response(self) -> Dict[str, Any]:
        """Return the configuration dict for streaming responses.

        Returns:
            Streaming configuration.
        """
        return {
            'stream_output': True,
            'buffer_size': 8192,  # 8KB buffer
            'chunk_size': 1024,  # 1KB chunks
            'enable_compression': True,
            'compression_level': 6  # Balanced compression
        }