import os
import logging
from pathlib import Path
from typing import Optional, Dict, Any

import torch
from huggingface_hub import snapshot_download, hf_hub_download

logger = logging.getLogger(__name__)


class HFSpacesModelCache:
    """Smart model caching for Hugging Face Spaces with storage optimization."""

    def __init__(self):
        self.cache_dir = Path("/tmp/hf_models_cache")  # Ephemeral cache in /tmp (cleared on restart)
        self.persistent_cache = Path("./model_cache")  # Small persistent cache

        # Ensure cache directories exist
        self.cache_dir.mkdir(exist_ok=True, parents=True)
        self.persistent_cache.mkdir(exist_ok=True, parents=True)

        # Small, essential models: downloaded and cached locally
        self.models_config = {
            "wav2vec2-base-960h": {
                "repo_id": "facebook/wav2vec2-base-960h",
                "cache_strategy": "download",  # Small model, can download
                "size_mb": 360,
                "essential": True,
            },
            "text-to-speech": {
                "repo_id": "microsoft/speecht5_tts",
                "cache_strategy": "download",  # For TTS functionality
                "size_mb": 500,
                "essential": True,
            },
        }

        # Large models: use a different strategy (streamed or lazily loaded)
        self.large_models_config = {
            "Wan2.1-T2V-14B": {
                "repo_id": "Wan-AI/Wan2.1-T2V-14B",
                "cache_strategy": "streaming",  # Stream from HF Hub
                "size_gb": 28,
                "essential": False,  # Can work without it
            },
            "OmniAvatar-14B": {
                "repo_id": "OmniAvatar/OmniAvatar-14B",
                "cache_strategy": "lazy_load",  # Load on demand
                "size_gb": 2,
                "essential": False,
            },
        }

        # Populated by _setup_large_model_streaming(); kept empty until then so
        # load_model_streaming() can be called safely at any time
        self.streaming_config: Dict[str, Any] = {}

    def setup_smart_caching(self):
        """Set up intelligent caching for HF Spaces."""
        logger.info("Setting up smart model caching for HF Spaces...")

        # Download only the essential small models
        for model_name, config in self.models_config.items():
            if config["essential"] and config["cache_strategy"] == "download":
                self._cache_small_model(model_name, config)

        # Set up streaming/lazy loading for large models
        self._setup_large_model_streaming()

    def _cache_small_model(self, model_name: str, config: Dict[str, Any]) -> Optional[str]:
        """Cache small essential models locally."""
        try:
            cache_path = self.persistent_cache / model_name
            if cache_path.exists():
                logger.info(f"{model_name} already cached")
                return str(cache_path)

            logger.info(f"Downloading {model_name} ({config['size_mb']} MB)...")

            # Use the HF Hub to download into our cache
            downloaded_path = snapshot_download(
                repo_id=config["repo_id"],
                cache_dir=str(cache_path),
                local_files_only=False,
            )

            logger.info(f"{model_name} cached successfully")
            return downloaded_path
        except Exception as e:
            logger.error(f"Failed to cache {model_name}: {e}")
            return None

    def _setup_large_model_streaming(self):
        """Set up streaming access for large models."""
        logger.info("Setting up streaming access for large models...")

        # Environment variables for faster, cache-aware downloads
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        os.environ["HF_HUB_CACHE"] = str(self.cache_dir)

        # Default kwargs used when loading large models
        self.streaming_config = {
            "use_cache": True,
            "low_cpu_mem_usage": True,
            "torch_dtype": torch.float16,  # Half precision to save memory
            "device_map": "auto",
        }

        logger.info("Streaming configuration ready")

    def get_model_path_or_stream(self, model_name: str) -> Optional[str]:
        """Return a local path for small cached models, or a repo_id for large models."""
        # Small cached model: return its local path if present
        if model_name in self.models_config:
            cache_path = self.persistent_cache / model_name
            if cache_path.exists():
                return str(cache_path)

        # Large model: return the repo_id so it can be streamed from the Hub
        if model_name in self.large_models_config:
            config = self.large_models_config[model_name]
            logger.info(f"{model_name} will be streamed from HF Hub")
            return config["repo_id"]  # Return repo_id for streaming

        return None

    def load_model_streaming(self, repo_id: str, **kwargs):
        """Load a model directly from the HF Hub, without a dedicated local copy."""
        try:
            from transformers import AutoModel, AutoProcessor

            logger.info(f"Streaming model from {repo_id}...")

            # Merge streaming defaults with caller-provided kwargs
            load_kwargs = {**self.streaming_config, **kwargs}

            # Load the model from the HF Hub (weights are cached in self.cache_dir)
            model = AutoModel.from_pretrained(repo_id, **load_kwargs)

            logger.info("Model loaded via streaming")
            return model
        except Exception as e:
            logger.error(f"Streaming failed for {repo_id}: {e}")
            return None


# Global cache manager instance
model_cache_manager = HFSpacesModelCache()
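

# Usage sketch (an illustrative addition, not part of the original module):
# one way the global cache manager might be exercised at Space startup.
# The model names below come from the configs defined above; running this
# block downloads the small models and attempts a streamed load, so treat
# it as a sketch rather than a definitive entry point.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Cache the small essential models and prepare the streaming defaults
    model_cache_manager.setup_smart_caching()

    # Small models resolve to a local cache path
    wav2vec_path = model_cache_manager.get_model_path_or_stream("wav2vec2-base-960h")
    logger.info(f"wav2vec2 resolved to: {wav2vec_path}")

    # Large models resolve to a repo_id that can be streamed on demand
    avatar_ref = model_cache_manager.get_model_path_or_stream("OmniAvatar-14B")
    if avatar_ref is not None:
        avatar_model = model_cache_manager.load_model_streaming(avatar_ref)
        logger.info(f"OmniAvatar loaded: {avatar_model is not None}")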