"""
Cache management and SAM2 loading utilities.

A comprehensive cache-cleaning system to resolve model-loading issues on HF Spaces.
"""

import gc
import logging
import os
import shutil
import sys
import tempfile
import traceback
from typing import Any, Optional, Tuple

logger = logging.getLogger(__name__)

class HardCacheCleaner:
    """
    Comprehensive cache-cleaning system to resolve SAM2 loading issues.

    Clears the Python module and import caches, the HuggingFace and PyTorch
    caches, and temp files.
    """

    @staticmethod
    def clean_all_caches(verbose: bool = True) -> None:
        """Clean all caches that might interfere with SAM2 loading."""
        if verbose:
            logger.info("Starting comprehensive cache cleanup...")

        HardCacheCleaner._clean_python_cache(verbose)
        HardCacheCleaner._clean_huggingface_cache(verbose)
        HardCacheCleaner._clean_pytorch_cache(verbose)
        HardCacheCleaner._clean_temp_directories(verbose)
        HardCacheCleaner._clear_import_cache(verbose)
        HardCacheCleaner._force_gc_cleanup(verbose)

        if verbose:
            logger.info("Cache cleanup completed")

    @staticmethod
    def _clean_python_cache(verbose: bool = True) -> None:
        """Evict cached SAM2 modules and remove Python bytecode caches."""
        try:
            # Drop any already-imported sam2 modules so they are re-imported fresh.
            sam2_modules = [key for key in sys.modules if "sam2" in key.lower()]
            for module in sam2_modules:
                if verbose:
                    logger.info(f"Removing cached module: {module}")
                del sys.modules[module]

            # Remove every __pycache__ directory under the working directory.
            for root, dirs, _files in os.walk("."):
                for dir_name in dirs[:]:
                    if dir_name == "__pycache__":
                        cache_path = os.path.join(root, dir_name)
                        if verbose:
                            logger.info(f"Removing __pycache__: {cache_path}")
                        shutil.rmtree(cache_path, ignore_errors=True)
                        dirs.remove(dir_name)  # do not descend into it

        except Exception as e:
            logger.warning(f"Python cache cleanup failed: {e}")

    @staticmethod
    def _clean_huggingface_cache(verbose: bool = True) -> None:
        """Remove SAM2 artifacts from the HuggingFace and model cache directories."""
        try:
            from config.app_config import get_config

            config = get_config()

            cache_paths = [
                os.path.expanduser("~/.cache/huggingface/"),
                os.path.expanduser("~/.cache/torch/"),
                config.model_cache_dir,
                "./checkpoints/",
                "./.cache/",
            ]

            patterns = ("sam2", "segment-anything-2")
            for cache_path in cache_paths:
                if not os.path.exists(cache_path):
                    continue
                if verbose:
                    logger.info(f"Cleaning cache directory: {cache_path}")

                for root, dirs, files in os.walk(cache_path):
                    # Delete cached SAM2 files.
                    for file in files:
                        if any(pattern in file.lower() for pattern in patterns):
                            file_path = os.path.join(root, file)
                            try:
                                os.remove(file_path)
                                if verbose:
                                    logger.info(f"Removed cached file: {file_path}")
                            except OSError:
                                pass

                    # Delete cached SAM2 directories and prune them from the walk.
                    for dir_name in dirs[:]:
                        if any(pattern in dir_name.lower() for pattern in patterns):
                            dir_path = os.path.join(root, dir_name)
                            shutil.rmtree(dir_path, ignore_errors=True)
                            if verbose:
                                logger.info(f"Removed cached directory: {dir_path}")
                            dirs.remove(dir_name)

        except Exception as e:
            logger.warning(f"HuggingFace cache cleanup failed: {e}")

    @staticmethod
    def _clean_pytorch_cache(verbose: bool = True) -> None:
        """Clean the PyTorch CUDA cache."""
        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                if verbose:
                    logger.info("Cleared PyTorch CUDA cache")
        except Exception as e:
            logger.warning(f"PyTorch cache cleanup failed: {e}")

    @staticmethod
    def _clean_temp_directories(verbose: bool = True) -> None:
        """Remove SAM2-related items from temporary directories."""
        try:
            from config.app_config import get_config

            config = get_config()

            temp_dirs = [
                config.temp_dir,
                tempfile.gettempdir(),
                "/tmp",
                "./tmp",
                "./temp",
            ]

            for temp_dir in temp_dirs:
                if not os.path.exists(temp_dir):
                    continue
                for item in os.listdir(temp_dir):
                    if "sam2" in item.lower() or "segment" in item.lower():
                        item_path = os.path.join(temp_dir, item)
                        try:
                            if os.path.isfile(item_path):
                                os.remove(item_path)
                            elif os.path.isdir(item_path):
                                shutil.rmtree(item_path, ignore_errors=True)
                            if verbose:
                                logger.info(f"Removed temp item: {item_path}")
                        except OSError:
                            pass

        except Exception as e:
            logger.warning(f"Temp directory cleanup failed: {e}")

    @staticmethod
    def _clear_import_cache(verbose: bool = True) -> None:
        """Clear Python's import finder caches."""
        try:
            import importlib

            importlib.invalidate_caches()

            if verbose:
                logger.info("Cleared Python import cache")

        except Exception as e:
            logger.warning(f"Import cache cleanup failed: {e}")

    @staticmethod
    def _force_gc_cleanup(verbose: bool = True) -> None:
        """Force a garbage-collection pass."""
        try:
            collected = gc.collect()
            if verbose:
                logger.info(f"Garbage collection freed {collected} objects")
        except Exception as e:
            logger.warning(f"Garbage collection failed: {e}")


class WorkingSAM2Loader:
    """
    SAM2 loader built on the HuggingFace Transformers integration, which has
    proven reliable on HF Spaces. This avoids the config-file and CUDA
    compilation issues of installing the official SAM2 package from source.
    """

    @staticmethod
    def load_sam2_transformers_approach(device: str = "cuda", model_size: str = "large") -> Optional[Any]:
        """
        Load SAM2 via the HuggingFace Transformers integration.

        Tries three methods in order and returns the first that succeeds, or
        None if all fail. Note that the return type differs per method.
        """
        try:
            logger.info("Loading SAM2 via HuggingFace Transformers...")

            model_map = {
                "tiny": "facebook/sam2.1-hiera-tiny",
                "small": "facebook/sam2.1-hiera-small",
                "base": "facebook/sam2.1-hiera-base-plus",
                "large": "facebook/sam2.1-hiera-large",
            }
            model_id = model_map.get(model_size, model_map["large"])
            logger.info(f"Using model: {model_id}")

            # Method 1: the high-level mask-generation pipeline.
            try:
                from transformers import pipeline

                sam2_pipeline = pipeline(
                    "mask-generation",
                    model=model_id,
                    device=0 if device == "cuda" else -1,
                )
                logger.info("SAM2 loaded successfully via Transformers pipeline")
                return sam2_pipeline
            except Exception as e:
                logger.warning(f"Pipeline approach failed: {e}")

            # Method 2: the Sam2Model/Sam2Processor classes directly.
            try:
                from transformers import Sam2Model, Sam2Processor

                processor = Sam2Processor.from_pretrained(model_id)
                model = Sam2Model.from_pretrained(model_id).to(device)
                logger.info("SAM2 loaded successfully via Transformers classes")
                return {"model": model, "processor": processor}
            except Exception as e:
                logger.warning(f"Direct class approach failed: {e}")

            # Method 3: the official SAM2 package's Hub integration.
            try:
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                predictor = SAM2ImagePredictor.from_pretrained(model_id)
                logger.info("SAM2 loaded successfully via official from_pretrained")
                return predictor
            except Exception as e:
                logger.warning(f"Official from_pretrained approach failed: {e}")

            return None
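
    # Usage note (a hedged sketch): the three methods above return different
    # object types, so callers must branch on what they got back. Assuming a
    # local image file such as "example.jpg":
    #
    #     sam2 = WorkingSAM2Loader.load_sam2_transformers_approach("cpu", "tiny")
    #     if callable(sam2):                      # mask-generation pipeline
    #         outputs = sam2("example.jpg")       # dict with "masks" and "scores"
    #     elif isinstance(sam2, dict):            # {"model": ..., "processor": ...}
    #         model, processor = sam2["model"], sam2["processor"]
    #     elif sam2 is not None:                  # SAM2ImagePredictor
    #         sam2.set_image(...)                 # then sam2.predict(...)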

        except Exception as e:
            logger.error(f"All SAM2 loading methods failed: {e}")
            return None

    @staticmethod
    def load_sam2_fallback_approach(device: str = "cuda") -> Optional[Any]:
        """Fallback approach: fetch the raw checkpoint, then load via Transformers."""
        try:
            logger.info("Trying fallback SAM2 loading approach...")

            from huggingface_hub import hf_hub_download

            # Download the raw checkpoint first. checkpoint_path is not consumed
            # below; the download mainly verifies that the Hub is reachable and
            # warms the local cache before the transformers load is attempted.
            # NOTE: assumes the SAM 2.1 checkpoint naming ("sam2.1_" prefix).
            checkpoint_path = hf_hub_download(
                repo_id="facebook/sam2.1-hiera-large",
                filename="sam2.1_hiera_large.pt",
            )
            logger.info(f"Downloaded checkpoint to: {checkpoint_path}")

            try:
                from transformers import Sam2Model

                model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-large")
                return model.to(device)
            except Exception as e:
                logger.warning(f"Transformers fallback failed: {e}")
                return None

        except Exception as e:
            logger.error(f"Fallback loading failed: {e}")
            return None


def load_sam2_with_cache_cleanup(
    device: str = "cuda",
    model_size: str = "large",
    force_cache_clean: bool = True,
    verbose: bool = True,
) -> Tuple[Optional[Any], str]:
    """
    Load SAM2, optionally preceded by a comprehensive cache cleanup.

    Returns:
        Tuple of (model, status_message); model is None if every loading
        method failed.
    """
    status_messages = []

    try:
        if force_cache_clean:
            status_messages.append("Cleaning caches...")
            HardCacheCleaner.clean_all_caches(verbose=verbose)
            status_messages.append("Cache cleanup completed")

        # Primary method: the Transformers integration.
        status_messages.append("Loading SAM2 (primary method)...")
        model = WorkingSAM2Loader.load_sam2_transformers_approach(device, model_size)
        if model is not None:
            status_messages.append("SAM2 loaded successfully!")
            return model, "\n".join(status_messages)

        # Fallback: raw checkpoint download plus Transformers load.
        status_messages.append("Trying fallback loading method...")
        model = WorkingSAM2Loader.load_sam2_fallback_approach(device)
        if model is not None:
            status_messages.append("SAM2 loaded successfully (fallback)!")
            return model, "\n".join(status_messages)

        status_messages.append("All SAM2 loading methods failed")
        return None, "\n".join(status_messages)

    except Exception as e:
        error_msg = f"Critical error in SAM2 loading: {e}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        status_messages.append(error_msg)
        return None, "\n".join(status_messages)
