| """ |
| Hugging Face Hub service for downloading model repositories. |
| """ |
|
|
| import os |
| from pathlib import Path |
| from typing import Optional |
|
|
| from huggingface_hub import snapshot_download |
| from huggingface_hub.utils import HfHubHTTPError |
|
|
| from app.core.config import settings |
| from app.core.errors import HuggingFaceDownloadError |
| from app.core.logging import get_logger |
|
|
logger = get_logger(__name__)


# Ask huggingface_hub not to emit its symlink-support warning.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# huggingface_hub evaluates this environment variable into
# ``huggingface_hub.constants`` when the package is first imported. That import
# already happened at the top of this module, so the assignment above does not
# reach the already-loaded flag. It only affects code that imports
# huggingface_hub afresh, e.g. child processes. Update the loaded flag as well so
# the setting also applies in this process.
# NOTE(review): huggingface_hub versions whose submodules copy the flag via
# ``from .constants import ...`` will still not see this update. The fully
# reliable fix is to set the variable before any huggingface_hub import
# (e.g. at process start).
from huggingface_hub import constants as _hf_constants  # noqa: E402

if hasattr(_hf_constants, "HF_HUB_DISABLE_SYMLINKS_WARNING"):
    _hf_constants.HF_HUB_DISABLE_SYMLINKS_WARNING = True
|
|
|
|
class HFHubService:
    """
    Service for interacting with Hugging Face Hub.

    Downloads model repositories into one folder per repository under a local
    cache directory (``<cache_dir>/<owner>--<name>``). It can also report
    whether such a folder is already populated.
    """

    def __init__(self, cache_dir: Optional[str] = None, token: Optional[str] = None):
        """
        Initialize the HF Hub service and make sure the cache directory exists.

        Args:
            cache_dir: Local directory for caching downloads.
                Defaults to settings.HF_CACHE_DIR. Any falsy value, including
                "", falls back to the default.
            token: Hugging Face API token for private repos.
                Defaults to settings.HF_TOKEN. Any falsy value falls back to
                the default.
        """
        self.cache_dir = cache_dir or settings.HF_CACHE_DIR
        self.token = token or settings.HF_TOKEN

        # Create the cache root eagerly so downloads and lookups can rely on it.
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        logger.info(f"HF Hub service initialized with cache dir: {self.cache_dir}")

    def _local_dir(self, repo_id: str) -> Path:
        """Return the folder under the cache dir used to store ``repo_id``."""
        # "owner/name" -> "owner--name" so every repo maps to one flat folder.
        return Path(self.cache_dir) / repo_id.replace("/", "--")

    def download_repo(
        self,
        repo_id: str,
        revision: Optional[str] = None,
        force_download: bool = False
    ) -> str:
        """
        Download a repository from Hugging Face Hub.

        Delegates to ``snapshot_download`` with ``local_dir`` set to this
        repo's folder under the cache dir. Files already present there are
        expected to be reused unless ``force_download`` is True.

        NOTE(review): the target folder depends only on ``repo_id``, so
        different revisions of the same repo share (and overwrite) one folder.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            revision: Git revision (branch, tag, or commit hash). Defaults to "main"
            force_download: If True, re-download even if cached

        Returns:
            Local path to the downloaded repository, as returned by snapshot_download.

        Raises:
            HuggingFaceDownloadError: If the download fails for any reason. The
                underlying exception is chained as ``__cause__``.
        """
        logger.info(f"Downloading repo: {repo_id} (revision={revision}, force={force_download})")

        try:
            # The folder computation stays inside the try so that a malformed
            # repo_id is also reported as HuggingFaceDownloadError.
            local_path = snapshot_download(
                repo_id=repo_id,
                revision=revision or "main",
                local_dir=str(self._local_dir(repo_id)),
                token=self.token,
                force_download=force_download,
                local_files_only=False,
            )
        except HfHubHTTPError as e:
            logger.error(f"HTTP error downloading {repo_id}: {e}")
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)},
            ) from e
        except Exception as e:
            logger.error(f"Error downloading {repo_id}: {e}")
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)},
            ) from e

        logger.info(f"Downloaded {repo_id} to {local_path}")
        return local_path

    def get_cached_path(self, repo_id: str) -> Optional[str]:
        """
        Get the cached path for a repository if it exists.

        A repository counts as cached when its folder is a directory with at
        least one entry. Completeness is not checked, so a partially
        downloaded folder is also reported as cached.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            Local path if cached, None otherwise
        """
        local_dir = self._local_dir(repo_id)

        # is_dir() rather than exists(): a stray *file* at this path must be
        # reported as "not cached" instead of crashing iterdir() with
        # NotADirectoryError.
        if local_dir.is_dir() and any(local_dir.iterdir()):
            return str(local_dir)
        return None

    def is_cached(self, repo_id: str) -> bool:
        """
        Check if a repository is already cached.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            True if get_cached_path finds a populated folder, False otherwise
        """
        return self.get_cached_path(repo_id) is not None
|
|
|
|
| |
# Lazily created module-wide instance; populated on the first call to
# get_hf_hub_service().
_hf_hub_service: Optional[HFHubService] = None


def get_hf_hub_service() -> HFHubService:
    """
    Return the shared HFHubService, creating it on first use.

    The instance is built with default settings and then reused by every
    subsequent call.

    Returns:
        HFHubService instance
    """
    global _hf_hub_service
    if _hf_hub_service is not None:
        return _hf_hub_service
    _hf_hub_service = HFHubService()
    return _hf_hub_service
|
|