File size: 2,501 Bytes
79acde6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27e0318
79acde6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84fb7e3
79acde6
 
 
 
 
84fb7e3
79acde6
 
 
84fb7e3
79acde6
 
 
 
 
 
 
 
84fb7e3
79acde6
84fb7e3
79acde6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import logging
from huggingface_hub import hf_hub_download

# Root logging setup performed at import time. basicConfig configures the
# *root* logger for the whole process; it does nothing if the root logger
# already has handlers attached.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Logger for this module; its records propagate to the root handlers above.
logger = logging.getLogger(__name__)

# Registry of downloadable checkpoints, grouped by task:
#   model key -> (Hugging Face Hub repo id, file name inside that repo)
# Keyword order is preserved by dict(), so iteration follows the grouping below.
# NOTE(review): repo ids / file names are not validated here — confirm each
# entry exists on the Hub before relying on it.
HF_MODELS = dict(
    # --- Depth estimation ---
    dpt_hybrid_384=("isl-org/MiDaS", "dpt_hybrid_384.pt"),
    midas_v21_small_256=("isl-org/MiDaS", "midas_v21_small_256.pt"),
    midas_v21_384=("isl-org/MiDaS", "midas_v21_384.pt"),
    dpt_swin2_large_384=("isl-org/MiDaS", "dpt_swin2_large_384.pt"),
    dpt_beit_large_512=("isl-org/MiDaS", "dpt_beit_large_512.pt"),
    # --- Object detection ---
    yolov8n=("ultralytics/yolov8", "yolov8n.pt"),
    yolov8s=("ultralytics/yolov8", "yolov8s.pt"),
    yolov8l=("ultralytics/yolov8", "yolov8l.pt"),
    # NOTE(review): Ultralytics appears to name YOLO11 weights "yolo11<size>.pt"
    # (no "v", no "b" size); "yolov11b.pt" may not exist in this repo — verify.
    yolov11b=("Ultralytics/YOLO11", "yolov11b.pt"),
    rtdetr=("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth"),
    # --- Semantic segmentation ---
    segformer_b0=("nvidia/segformer-b0-finetuned-ade-512-512", "model.safetensors"),
    segformer_b5=("nvidia/segformer-b5-finetuned-ade-512-512", "model.safetensors"),
    deeplabv3_resnet50=("facebook/deeplabv3-resnet50", "pytorch_model.bin"),
)


def download_model_if_needed(model_key: str, save_path: str):
    """
    Download a registered model file from the Hugging Face Hub unless it is already on disk.

    The file is fetched into the directory that contains ``save_path`` and then
    renamed to exactly ``save_path``. Later calls therefore see the file and
    skip the download.

    Args:
        model_key (str): Key from HF_MODELS dict.
        save_path (str): Local path to store the downloaded model. A bare file
            name (no directory part) is resolved against the current directory.

    Raises:
        ValueError: If the model_key is not supported.
        Exception: Any error from ``hf_hub_download`` or the final rename is
            logged and re-raised unchanged.
    """
    if model_key not in HF_MODELS:
        logger.error(f" Model key '{model_key}' not found in registry.")
        raise ValueError(f"Unsupported model key: {model_key}")

    repo_id, filename = HF_MODELS[model_key]

    if os.path.exists(save_path):
        logger.info(f" Model '{model_key}' already exists at '{save_path}'. Skipping download.")
        return

    # os.path.dirname() returns "" for a bare file name, and os.makedirs("")
    # raises FileNotFoundError, so fall back to the current directory.
    target_dir = os.path.dirname(save_path) or "."
    os.makedirs(target_dir, exist_ok=True)
    logger.info(f" Downloading '{model_key}' from Hugging Face Hub...")

    try:
        # Use local_dir rather than cache_dir. cache_dir stores the file inside
        # the hub's hashed cache layout (models--*/snapshots/*), so it never
        # appears at save_path and the existence check above never fires.
        # local_dir writes <target_dir>/<filename>, and the call returns that path.
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=target_dir,
        )
        # Rename to the exact requested name so future calls short-circuit.
        # Assumes a huggingface_hub version in which local_dir holds real files
        # rather than symlinks into the cache — TODO confirm pinned version.
        if os.path.abspath(downloaded_path) != os.path.abspath(save_path):
            os.replace(downloaded_path, save_path)
        logger.info(f" Successfully downloaded '{model_key}' to '{save_path}'")
    except Exception as e:
        logger.error(f" Download failed for '{model_key}': {e}")
        raise