| | """ |
| | Model utilities for downloading and managing the marine species model. |
| | """ |
| |
|
| | import os |
| | import shutil |
| | from pathlib import Path |
| | from typing import Optional, Dict, Any |
| | from huggingface_hub import hf_hub_download, list_repo_files |
| |
|
| | from app.core.config import settings |
| | from app.core.logging import get_logger |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
def download_model_from_hf(
    repo_id: str,
    model_filename: str,
    local_dir: str,
    force_download: bool = False
) -> str:
    """
    Download a model file from the HuggingFace Hub into a local directory.

    If ``local_dir / model_filename`` already exists as a regular file and
    ``force_download`` is False, no download is attempted and the existing
    path is returned.

    Args:
        repo_id: HuggingFace repository ID
        model_filename: Name of the model file inside the repository
        local_dir: Local directory to save the model (created if missing)
        force_download: Whether to force re-download if the file exists

    Returns:
        Path to the model file on disk

    Raises:
        Exception: any error from creating the directory or from
            ``hf_hub_download`` is logged and re-raised unchanged.
    """
    try:
        Path(local_dir).mkdir(parents=True, exist_ok=True)

        local_path = Path(local_dir) / model_filename

        # is_file() rather than exists(): a directory (or other non-regular
        # entry) carrying the model's name is not a usable model and must not
        # short-circuit the download.
        if local_path.is_file() and not force_download:
            logger.info(f"Model already exists at {local_path}")
            return str(local_path)

        logger.info(f"Downloading {model_filename} from {repo_id}...")

        # NOTE(review): newer huggingface_hub releases deprecate
        # ``local_dir_use_symlinks``; it is kept here for older versions.
        # Confirm against the pinned huggingface_hub version.
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=model_filename,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            force_download=force_download,
        )

        logger.info(f"Model downloaded successfully to: {downloaded_path}")
        return downloaded_path

    except Exception as e:
        logger.error(f"Failed to download model: {str(e)}")
        raise
| |
|
| |
|
def list_available_files(repo_id: str) -> list:
    """
    Fetch the names of every file stored in a HuggingFace repository.

    This is a best-effort lookup: any failure is logged and an empty list is
    returned instead of raising.

    Args:
        repo_id: HuggingFace repository ID

    Returns:
        The repository's file names, or an empty list when listing failed
    """
    try:
        return list_repo_files(repo_id)
    except Exception as exc:
        logger.error(f"Failed to list repository files: {str(exc)}")
        return []
| |
|
| |
|
def verify_model_file(model_path: str, min_size_bytes: int = 1024 * 1024) -> bool:
    """
    Verify that a model file exists and looks plausible.

    The check fails when:
    - the path is missing;
    - the path is not a regular file;
    - the file is smaller than ``min_size_bytes``.

    An extension other than ``.pt`` or ``.pth`` only produces a warning.

    Args:
        model_path: Path to the model file
        min_size_bytes: Smallest size (in bytes) accepted as a real model
            file; defaults to 1 MiB

    Returns:
        True if the model file passed all checks, False otherwise (including
        when an unexpected error occurred; the error is logged)
    """
    try:
        path = Path(model_path)

        if not path.exists():
            logger.error(f"Model file does not exist: {model_path}")
            return False

        # exists() is also true for directories; those are never valid models.
        if not path.is_file():
            logger.error(f"Model path is not a regular file: {model_path}")
            return False

        file_size = path.stat().st_size
        if file_size < min_size_bytes:
            logger.warning(f"Model file seems too small: {file_size} bytes")
            return False

        # Unknown extensions are tolerated, only flagged.
        if path.suffix.lower() not in (".pt", ".pth"):
            logger.warning(f"Unexpected model file extension: {path.suffix}")

        logger.info(f"Model file verified: {model_path} ({file_size / (1024*1024):.1f} MB)")
        return True

    except Exception as e:
        logger.error(f"Failed to verify model file: {str(e)}")
        return False
| |
|
| |
|
def get_model_info(model_path: str) -> Dict[str, Any]:
    """
    Get information about a model file.

    Args:
        model_path: Path to the model file

    Returns:
        Dictionary with the keys ``path``, ``exists``, ``size_mb`` and
        ``size_bytes``. ``modified_time`` (``st_mtime``) is added only when
        the path exists. On a missing path or an error, ``exists`` stays
        False and the sizes stay 0.
    """
    info: Dict[str, Any] = {
        "path": model_path,
        "exists": False,
        "size_mb": 0,
        "size_bytes": 0,
    }

    try:
        # One stat() both proves existence and yields size/mtime, so the
        # reported fields always describe the same state of the file (no
        # window between an exists() check and later stat() calls).
        stat_result = Path(model_path).stat()
    except (FileNotFoundError, NotADirectoryError):
        # Missing path: same quiet result as Path.exists() returning False.
        return info
    except Exception as e:
        logger.error(f"Failed to get model info: {str(e)}")
        return info

    info["exists"] = True
    info["size_bytes"] = stat_result.st_size
    info["size_mb"] = stat_result.st_size / (1024 * 1024)
    info["modified_time"] = stat_result.st_mtime
    return info
| |
|
| |
|
def _default_hf_hub_cache() -> Path:
    """Resolve the HuggingFace Hub download cache directory like huggingface_hub does."""
    # Precedence:
    #   1. HF_HUB_CACHE
    #   2. legacy HUGGINGFACE_HUB_CACHE
    #   3. $HF_HOME/hub
    #   4. $XDG_CACHE_HOME/huggingface/hub
    #   5. ~/.cache/huggingface/hub
    # Empty values are treated as unset so an empty variable can never
    # resolve to the current working directory (which would then be deleted).
    hub_cache = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if hub_cache:
        return Path(hub_cache).expanduser()
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return Path(hf_home).expanduser() / "hub"
    xdg_cache = os.environ.get("XDG_CACHE_HOME")
    cache_root = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
    return cache_root / "huggingface" / "hub"


def cleanup_model_cache(cache_dir: Optional[str] = None) -> None:
    """
    Delete a model cache directory and everything inside it.

    This is best-effort: failures are logged, never raised.

    Args:
        cache_dir: Cache directory to delete. When None, the HuggingFace Hub
            download cache is used: ``HF_HUB_CACHE``, or ``$HF_HOME/hub``,
            by default ``~/.cache/huggingface/hub``. Only that ``hub``
            directory is targeted, so other data kept under the HuggingFace
            home survives. That data includes the saved access token and
            other libraries' caches.
    """
    try:
        cache_path = Path(cache_dir) if cache_dir is not None else _default_hf_hub_cache()

        if cache_path.exists():
            logger.info(f"Cleaning up cache directory: {cache_path}")
            shutil.rmtree(cache_path)
            logger.info("Cache cleanup completed")
        else:
            logger.info("Cache directory does not exist")

    except Exception as e:
        logger.error(f"Failed to cleanup cache: {str(e)}")
| |
|
| |
|
def setup_model_directory() -> str:
    """
    Make sure the directory that will hold the model file exists.

    The directory is the parent of ``settings.MODEL_PATH``. It is created
    along with any missing ancestors; an existing directory is left as is.

    Returns:
        The model directory as a string path
    """
    directory = Path(settings.MODEL_PATH).parent
    directory.mkdir(exist_ok=True, parents=True)
    logger.info(f"Model directory setup: {directory}")
    return str(directory)
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: each flag runs one maintenance action, and
    # several flags may be combined (download runs before verify/info).
    import argparse

    parser = argparse.ArgumentParser(description="Model management utility")
    parser.add_argument("--download", action="store_true", help="Download model from HuggingFace")
    parser.add_argument("--verify", action="store_true", help="Verify model file")
    parser.add_argument("--info", action="store_true", help="Show model information")
    parser.add_argument("--list-files", action="store_true", help="List available files in HF repo")
    parser.add_argument("--cleanup-cache", action="store_true", help="Cleanup model cache")
    parser.add_argument("--force", action="store_true", help="Force download even if file exists")

    args = parser.parse_args()

    # With no action flag there is nothing to do; show usage rather than
    # exiting silently. (--force alone is only a modifier for --download.)
    requested_actions = (
        args.download,
        args.verify,
        args.info,
        args.list_files,
        args.cleanup_cache,
    )
    if not any(requested_actions):
        parser.print_help()

    if args.download:
        setup_model_directory()
        download_model_from_hf(
            repo_id=settings.HUGGINGFACE_REPO,
            model_filename=f"{settings.MODEL_NAME}.pt",
            local_dir=str(Path(settings.MODEL_PATH).parent),
            force_download=args.force,
        )

    if args.verify:
        is_valid = verify_model_file(settings.MODEL_PATH)
        print(f"Model valid: {is_valid}")

    if args.info:
        info = get_model_info(settings.MODEL_PATH)
        print(f"Model info: {info}")

    if args.list_files:
        files = list_available_files(settings.HUGGINGFACE_REPO)
        print(f"Available files in {settings.HUGGINGFACE_REPO}:")
        for file in files:
            print(f" - {file}")

    if args.cleanup_cache:
        cleanup_model_cache()
| |
|