# Spaces: Runtime error (status banner captured with the file; not part of the module)
"""
Inference Cache System for DittoTalkingHead
Caches video generation results for faster repeated processing
"""
import hashlib
import json
import os
import pickle
import shutil
import time
from datetime import datetime, timedelta
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union
class InferenceCache:
    """
    Cache system for video generation results

    Supports both memory and file-based caching. Cached videos are copied
    into ``cache_dir`` and described by a JSON metadata file stored in the
    same directory; a small in-memory LRU map of ``cache_key -> file path``
    sits in front of it.
    """

    def __init__(
        self,
        cache_dir: str = "/tmp/inference_cache",
        memory_cache_size: int = 100,
        file_cache_size_gb: float = 10.0,
        ttl_hours: int = 24
    ):
        """
        Initialize inference cache

        Args:
            cache_dir: Directory for file-based cache (created if missing)
            memory_cache_size: Maximum number of items in memory cache;
                a value <= 0 disables the memory tier
            file_cache_size_gb: Maximum size of file cache in GB
            ttl_hours: Time to live for cache entries in hours
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.memory_cache_size = memory_cache_size
        self.file_cache_size_bytes = int(file_cache_size_gb * 1024 * 1024 * 1024)
        self.ttl_seconds = ttl_hours * 3600

        # Metadata file for managing cache
        self.metadata_file = self.cache_dir / "cache_metadata.json"
        self.metadata = self._load_metadata()

        # In-memory cache: key -> cached file path, plus last access times
        self._memory_cache = {}
        self._access_times = {}

        # Clean up expired entries on initialization
        self._cleanup_expired()

    def _load_metadata(self) -> Dict[str, Any]:
        """
        Load cache metadata

        Returns:
            The parsed metadata mapping, or an empty dict when the file is
            missing, unreadable, not valid JSON, or not a JSON object.
        """
        if not self.metadata_file.exists():
            return {}
        try:
            with open(self.metadata_file, 'r') as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # Best effort: an unreadable/corrupt index just means an empty
            # cache (json.JSONDecodeError is a ValueError subclass).
            return {}
        return loaded if isinstance(loaded, dict) else {}

    def _save_metadata(self):
        """
        Save cache metadata

        The JSON is written to a sibling temporary file and then moved over
        the real file, so a failed serialisation never truncates the
        existing metadata.

        Raises:
            TypeError, ValueError: if the metadata is not JSON serialisable
            OSError: if the file cannot be written
        """
        tmp_file = self.metadata_file.with_name(self.metadata_file.name + ".tmp")
        try:
            with open(tmp_file, 'w') as f:
                json.dump(self.metadata, f, indent=2)
        except Exception:
            if tmp_file.exists():
                tmp_file.unlink()
            raise
        os.replace(tmp_file, self.metadata_file)

    @staticmethod
    def _hash_file(path: str, chunk_size: int = 1024 * 1024) -> str:
        """Return the SHA-256 hex digest of a file, read in chunks."""
        digest = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def generate_cache_key(
        self,
        audio_path: str,
        image_path: str,
        **kwargs
    ) -> str:
        """
        Generate unique cache key based on input parameters

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            **kwargs: Additional parameters affecting output. Only
                ``resolution`` (default '320x320'), ``steps`` (default 25)
                and ``seed`` (default None) take part in the key; any other
                keyword is ignored.

        Returns:
            SHA-256 hash as cache key

        Raises:
            OSError: if either input file cannot be read
        """
        # Hash file contents (not paths) so identical inputs share a key
        key_data = {
            'audio': self._hash_file(audio_path),
            'image': self._hash_file(image_path),
            'resolution': kwargs.get('resolution', '320x320'),
            'steps': kwargs.get('steps', 25),
            'seed': kwargs.get('seed', None)
        }

        # Generate final key
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get_from_memory(self, cache_key: str) -> Optional[str]:
        """
        Get video path from memory cache

        A memory entry whose metadata entry has expired, or whose file no
        longer exists, is dropped from the memory tier and reported as a
        miss. The file tier is left untouched here.

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        cached_path = self._memory_cache.get(cache_key)
        if cached_path is None:
            return None

        entry = self.metadata.get(cache_key)
        expired = entry is not None and time.time() > entry['expires_at']
        if expired or not os.path.exists(cached_path):
            self._memory_cache.pop(cache_key, None)
            self._access_times.pop(cache_key, None)
            return None

        self._access_times[cache_key] = time.time()
        return cached_path

    def get_from_file(self, cache_key: str) -> Optional[str]:
        """
        Get video path from file cache

        Expired entries and entries whose file is missing are removed and
        the removal is persisted. A hit refreshes ``last_access`` and
        promotes the entry into the memory cache.

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        entry = self.metadata.get(cache_key)
        if entry is None:
            return None

        video_path = self.cache_dir / entry['filename']
        if time.time() > entry['expires_at'] or not video_path.exists():
            self._remove_cache_entry(cache_key)
            self._save_metadata()
            return None

        # Update access time
        entry['last_access'] = time.time()
        self._save_metadata()

        # Add to memory cache
        self._add_to_memory_cache(cache_key, str(video_path))
        return str(video_path)

    def get(self, cache_key: str) -> Optional[str]:
        """
        Get video from cache (memory first, then file)

        Args:
            cache_key: Cache key

        Returns:
            Video file path if found, None otherwise
        """
        result = self.get_from_memory(cache_key)
        if result is not None:
            return result
        return self.get_from_file(cache_key)

    def put(
        self,
        cache_key: str,
        video_path: str,
        **metadata
    ) -> bool:
        """
        Store video in cache

        Args:
            cache_key: Cache key
            video_path: Path to generated video
            **metadata: Additional metadata to store (must be JSON
                serialisable)

        Returns:
            True if stored successfully; False if any step failed or the
            entry was evicted immediately by the size limit
        """
        try:
            # Reject unserialisable metadata before touching any state;
            # otherwise it would make every later metadata save fail.
            json.dumps(metadata)

            # Copy video to cache directory
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            cache_filename = f"{cache_key[:8]}_{timestamp}.mp4"
            cache_video_path = self.cache_dir / cache_filename
            shutil.copy2(video_path, cache_video_path)

            # Replacing an existing key: drop its old file so it is not
            # orphaned (same filename means the copy already overwrote it).
            previous = self.metadata.get(cache_key)
            if previous is not None and previous.get('filename') != cache_filename:
                old_file = self.cache_dir / previous['filename']
                if old_file.exists():
                    old_file.unlink()

            # Store metadata
            now = time.time()
            self.metadata[cache_key] = {
                'filename': cache_filename,
                'created_at': now,
                'expires_at': now + self.ttl_seconds,
                'last_access': now,
                'size_bytes': os.path.getsize(cache_video_path),
                'metadata': metadata
            }

            # Check cache size and clean if needed
            self._check_cache_size()

            # Save metadata
            self._save_metadata()

            # The size check may have evicted the entry just added
            if cache_key not in self.metadata:
                print("Error storing cache: entry exceeds cache size limit")
                return False

            # Add to memory cache
            self._add_to_memory_cache(cache_key, str(cache_video_path))
            return True
        except Exception as e:
            print(f"Error storing cache: {e}")
            return False

    def _add_to_memory_cache(self, cache_key: str, video_path: str):
        """
        Add item to memory cache with LRU eviction

        Re-adding an existing key only refreshes it. A non-positive
        ``memory_cache_size`` disables the memory tier.
        """
        if self.memory_cache_size <= 0:
            return

        if cache_key not in self._memory_cache:
            # Evict least recently used entries until there is room
            while (
                len(self._memory_cache) >= self.memory_cache_size
                and self._access_times
            ):
                lru_key = min(self._access_times, key=self._access_times.get)
                self._memory_cache.pop(lru_key, None)
                del self._access_times[lru_key]

        self._memory_cache[cache_key] = video_path
        self._access_times[cache_key] = time.time()

    def _check_cache_size(self):
        """
        Check and maintain cache size limit

        Removes entries in order of oldest ``last_access`` until the summed
        ``size_bytes`` is within the limit. Does not save the metadata.
        """
        total_size = sum(
            entry['size_bytes']
            for entry in self.metadata.values()
        )
        if total_size <= self.file_cache_size_bytes:
            return

        # Snapshot sorted by last access; safe to mutate metadata below
        oldest_first = sorted(
            self.metadata.items(),
            key=lambda item: item[1]['last_access']
        )
        for key_to_remove, entry in oldest_first:
            if total_size <= self.file_cache_size_bytes:
                break
            total_size -= entry['size_bytes']
            self._remove_cache_entry(key_to_remove)

    def _cleanup_expired(self):
        """Remove expired cache entries and persist the removal"""
        current_time = time.time()
        expired_keys = [
            key for key, entry in self.metadata.items()
            if current_time > entry['expires_at']
        ]
        for key in expired_keys:
            self._remove_cache_entry(key)

        if expired_keys:
            self._save_metadata()
            print(f"Cleaned up {len(expired_keys)} expired cache entries")

    def _remove_cache_entry(self, cache_key: str):
        """
        Remove a cache entry

        Deletes the cached file (if present), the metadata entry and the
        memory-cache entry. Does not save the metadata.
        """
        entry = self.metadata.pop(cache_key, None)
        if entry is not None:
            video_file = self.cache_dir / entry['filename']
            if video_file.exists():
                video_file.unlink()

        # Remove from memory cache
        self._memory_cache.pop(cache_key, None)
        self._access_times.pop(cache_key, None)

    def clear_cache(self):
        """
        Clear all cache entries

        Deletes every ``*.mp4`` file in the cache directory, resets and
        saves the metadata, and empties the memory cache.
        """
        # Remove all video files
        for file in self.cache_dir.glob("*.mp4"):
            file.unlink()

        # Clear metadata
        self.metadata = {}
        self._save_metadata()

        # Clear memory cache
        self._memory_cache.clear()
        self._access_times.clear()
        print("Inference cache cleared")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        Returns:
            Dict with entry counts for both tiers, total cached size in MB,
            the configured size limit in GB, the TTL in hours and the cache
            directory path.
        """
        total_size = sum(
            entry['size_bytes']
            for entry in self.metadata.values()
        )
        return {
            'memory_cache_entries': len(self._memory_cache),
            'file_cache_entries': len(self.metadata),
            'total_cache_size_mb': total_size / (1024 * 1024),
            'cache_size_limit_gb': self.file_cache_size_bytes / (1024 * 1024 * 1024),
            'ttl_hours': self.ttl_seconds / 3600,
            'cache_directory': str(self.cache_dir)
        }
class CachedInference:
    """
    Wrapper for cached inference execution

    Looks up a previously generated video for the given inputs before
    running the (expensive) inference function, and stores fresh results
    in the cache afterwards.
    """

    def __init__(self, cache: "InferenceCache"):
        """
        Initialize cached inference

        Args:
            cache: InferenceCache instance
        """
        self.cache = cache

    def process_with_cache(
        self,
        inference_func: Callable[..., Any],
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, bool, float]:
        """
        Process with caching

        Args:
            inference_func: Function to generate video; called as
                ``inference_func(audio_path, image_path, output_path, **kwargs)``
                and expected to write the video to ``output_path``
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Desired output path
            **kwargs: Additional parameters; forwarded to the cache key
                generation, to ``inference_func`` and, as metadata, to the
                cache

        Returns:
            Tuple of (output_path, cache_hit, process_time). On a miss the
            path is returned even if ``inference_func`` did not create the
            file; in that case nothing is stored in the cache.
        """
        start_time = time.time()

        # Generate cache key
        cache_key = self.cache.generate_cache_key(
            audio_path, image_path, **kwargs
        )

        # Check cache
        cached_video = self.cache.get(cache_key)
        if cached_video:
            # Cache hit - copy to output path
            try:
                shutil.copy2(cached_video, output_path)
            except shutil.SameFileError:
                # The cached file already is the requested output
                pass
            process_time = time.time() - start_time
            print(f"✅ Cache hit! Retrieved in {process_time:.2f}s")
            return output_path, True, process_time

        # Cache miss - generate video
        print("Cache miss - generating video...")
        inference_func(audio_path, image_path, output_path, **kwargs)

        # Store in cache
        if os.path.exists(output_path):
            self.cache.put(cache_key, output_path, **kwargs)

        process_time = time.time() - start_time
        return output_path, False, process_time