|
|
""" |
|
|
Lightweight model setup utility for dwpose-editor. |
|
|
|
|
|
Provides a CLI to pre-download required ONNX models into the local cache |
|
|
directory used by the app (./models), without initializing ONNX sessions. |
|
|
""" |
|
|
|
|
|
import os
import sys
from typing import List, Optional
|
|
|
|
|
# huggingface-hub is treated as optional at import time: if it cannot be
# imported, the name is set to None and download_models() reports an install
# hint instead of the module failing to import.
try:
    from huggingface_hub import hf_hub_download
except Exception:  # missing or broken install -> degrade gracefully
    hf_hub_download = None

# Hugging Face repository the model files are fetched from.
DEFAULT_REPO_ID = "yzd-v/DWPose"
# Local directory handed to hf_hub_download() as its cache_dir.
DEFAULT_CACHE_DIR = "./models"
# ONNX files that download_models() fetches from the repository.
REQUIRED_FILES: List[str] = ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]
|
|
|
|
|
|
|
|
def ensure_dir(path: str) -> None:
    """Create the directory *path* and any missing parents.

    Calling it on a directory that already exists is a no-op.
    """
    # mode=0o777 is os.makedirs' own default; exist_ok makes repeat calls safe.
    os.makedirs(name=path, mode=0o777, exist_ok=True)
|
|
|
|
|
|
|
|
def download_models(repo_id: str = DEFAULT_REPO_ID, cache_dir: str = DEFAULT_CACHE_DIR) -> int:
    """Download every file in REQUIRED_FILES from a Hugging Face repo into cache_dir.

    Each file is fetched with ``hf_hub_download(..., cache_dir=cache_dir)``. Per
    huggingface_hub semantics, ``cache_dir`` therefore holds the hub cache layout
    (``models--<org>--<name>/snapshots/...``) rather than flat copies of the files.
    A failure on one file is reported and the remaining files are still attempted.

    Args:
        repo_id: Hugging Face repository id to download from.
        cache_dir: Local cache directory; created (with parents) if missing.

    Returns:
        0 if every file was downloaded, 1 if the cache directory could not be
        created or any download failed, 2 if huggingface-hub is not installed.
    """
    if hf_hub_download is None:
        print(
            "[ERROR] huggingface-hub is not installed. Run: pip install huggingface-hub",
            file=sys.stderr,
        )
        return 2

    try:
        ensure_dir(cache_dir)
    except OSError as e:
        # Report instead of crashing so the documented return-code contract holds.
        print(f"[SETUP] ERROR: cannot create cache directory {cache_dir}: {e}", file=sys.stderr)
        return 1

    failed: List[str] = []
    for filename in REQUIRED_FILES:
        # ASCII arrow on purpose: the previous glyph was mis-encoded, and since the
        # print sat inside the try, a console unable to encode it turned into a
        # spurious "download failed" without ever attempting the download.
        print(f"[SETUP] Downloading {filename} from {repo_id} -> {cache_dir} ...")
        try:
            local_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        except Exception as e:  # best-effort per file: record, report, keep going
            failed.append(filename)
            print(f"[SETUP] ERROR: failed to download {filename}: {e}", file=sys.stderr)
        else:
            print(f"[SETUP] OK: {local_path}")

    return 1 if failed else 0
|
|
|
|
|
|
|
|
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: python -m utils.model_setup [--repo REPO] [--cache-dir DIR]

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        The exit code produced by download_models().
    """
    # NOTE: argv is annotated with Optional[...] instead of ``List[str] | None``;
    # the latter is evaluated when the def runs and raises TypeError on
    # Python < 3.10, which made the whole module unimportable there.
    import argparse

    parser = argparse.ArgumentParser(description="Pre-download DWPose models into ./models")
    parser.add_argument("--repo", default=DEFAULT_REPO_ID, help="Hugging Face repo id")
    parser.add_argument("--cache-dir", default=DEFAULT_CACHE_DIR, help="Local cache directory")
    args = parser.parse_args(argv)

    return download_models(repo_id=args.repo, cache_dir=args.cache_dir)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())
|
|
|
|
|
|