# aiBatteryLifeCycle/scripts/download_models.py
"""
Download model artifacts from Hugging Face Hub at container startup.
Called automatically by the Docker entrypoint before uvicorn starts.
Can also download a specific version on-demand (e.g. from the API).
HF model repo layout (v1/ and v2/ at repo root):
v1/models/classical/*.joblib
v1/models/deep/*.pt *.keras
v1/scalers/*.joblib
v2/models/classical/*.joblib
v2/models/deep/*.pt *.keras
v2/scalers/*.joblib
v2/results/*.json
Local layout after download (local_dir = ARTIFACTS_DIR):
artifacts/v1/...
artifacts/v2/...
"""
import os
import sys
import json
from pathlib import Path
# ──────────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────────
REPO_ID = "NeerajCodz/aiBatteryLifeCycle"
REPO_TYPE = "model"
# Token read from the HF_TOKEN Space Secret (set in Space Settings -> Secrets)
# For local use: set HF_TOKEN in your shell or .env before running
HF_TOKEN = os.getenv("HF_TOKEN", "")
# HF repo stores v1/, v2/ and v3/ at the root -> local_dir=ARTIFACTS_DIR maps them
# to artifacts/v1/..., artifacts/v2/... and artifacts/v3/...
ARTIFACTS_DIR = Path(__file__).resolve().parent.parent / "artifacts"
DEFAULT_STARTUP_TOP_MODELS = 3
# Sentinel file -- written after a successful full download
SENTINEL = ARTIFACTS_DIR / ".hf_downloaded"
# ──────────────────────────────────────────────────────────────────────────────
def _hf_kwargs(allow_patterns: list | None = None,
ignore_patterns: list | None = None) -> dict:
"""Build kwargs for snapshot_download; inject token only when non-empty."""
kwargs: dict = dict(
repo_id=REPO_ID,
repo_type=REPO_TYPE,
local_dir=str(ARTIFACTS_DIR),
)
if allow_patterns:
kwargs["allow_patterns"] = allow_patterns
if ignore_patterns:
kwargs["ignore_patterns"] = ignore_patterns
if HF_TOKEN:
kwargs["token"] = HF_TOKEN
return kwargs
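# Illustrative result, assuming HF_TOKEN is unset:
#   _hf_kwargs(allow_patterns=["v3/**"])
#   -> {"repo_id": REPO_ID, "repo_type": "model",
#       "local_dir": str(ARTIFACTS_DIR), "allow_patterns": ["v3/**"]}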
def _key_models(version: str = "v3") -> list:
base = ARTIFACTS_DIR / version / "models" / "classical"
return [base / f"{m}.joblib" for m in ("random_forest", "xgboost", "lightgbm")]
def version_loaded(version: str) -> bool:
"""Return True when the given version's key models exist on disk."""
return all(p.exists() for p in _key_models(version))
def already_downloaded(version: str = "v3") -> bool:
"""Return True only when all three BestEnsemble component models are present."""
missing = [p for p in _key_models(version) if not p.exists()]
if missing:
if SENTINEL.exists():
SENTINEL.unlink()
print(f"[download_models] Sentinel stale ({len(missing)} key models missing) β€” will re-download")
return False
return True
def _ensure_hub():
try:
from huggingface_hub import snapshot_download # noqa: F401
except ImportError:
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install",
"huggingface_hub>=0.23", "-q"])
def write_datamap(version: str) -> None:
"""Write artifacts/<version>/datamap.json listing all locally available files."""
vroot = ARTIFACTS_DIR / version
vroot.mkdir(parents=True, exist_ok=True)
items = []
for p in sorted(vroot.rglob("*")):
if not p.is_file():
continue
rel = p.relative_to(vroot).as_posix()
items.append({
"path": rel,
"bytes": p.stat().st_size,
})
out = {
"version": version,
"count": len(items),
"files": items,
}
(vroot / "datamap.json").write_text(json.dumps(out, indent=2), encoding="utf-8")
def download_version(version: str) -> None:
"""Download a single version (e.g. 'v1' or 'v2') from HF Hub into artifacts/."""
_ensure_hub()
from huggingface_hub import snapshot_download
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
print(f"[download_models] Downloading {version}/ from {REPO_ID} -> {ARTIFACTS_DIR}")
snapshot_download(**_hf_kwargs(
allow_patterns=[f"{version}/**"],
ignore_patterns=["*.log"],
))
write_datamap(version)
print(f"[download_models] {version}/ ready")
def _read_models_meta(version: str) -> dict:
meta_path = ARTIFACTS_DIR / version / "models.json"
if not meta_path.exists():
return {}
try:
return json.loads(meta_path.read_text(encoding="utf-8"))
except Exception:
return {}
def _model_score(info: dict) -> float:
"""Return a comparable score for a model metadata record.
Priority:
1) R2 (higher is better)
2) f1_weighted / f1_macro (higher is better)
3) within_5pct (higher is better)
4) mae (lower is better; converted to negative contribution)
"""
if not isinstance(info, dict):
return float("-inf")
r2 = info.get("r2")
if isinstance(r2, (int, float)):
return float(r2)
# Fallback when r2 is unavailable.
f1w = info.get("f1_weighted")
if isinstance(f1w, (int, float)):
return float(f1w)
f1m = info.get("f1_macro")
if isinstance(f1m, (int, float)):
return float(f1m)
w5 = info.get("within_5pct")
if isinstance(w5, (int, float)):
return float(w5) / 100.0
mae = info.get("mae")
if isinstance(mae, (int, float)):
return -float(mae)
return float("-inf")
def select_top_models(version: str, top_k: int = DEFAULT_STARTUP_TOP_MODELS) -> list[str]:
"""Pick top-k concrete models from models.json by best available metric score.
Excludes virtual entries without a model file.
"""
meta = _read_models_meta(version)
models = meta.get("models", {}) if isinstance(meta, dict) else {}
scored: list[tuple[float, str]] = []
for name, info in models.items():
if not isinstance(info, dict):
continue
if not info.get("file"):
continue
score = _model_score(info)
if score != float("-inf"):
scored.append((score, name))
scored.sort(key=lambda x: x[0], reverse=True)
return [name for _, name in scored[:top_k]]
def _collect_model_patterns(version: str, model_names: list[str]) -> list[str]:
"""Build allow_patterns for model files + required shared artifacts."""
meta = _read_models_meta(version)
models = meta.get("models", {}) if isinstance(meta, dict) else {}
patterns = {
f"{version}/models.json",
f"{version}/features/train_split.csv",
f"{version}/scalers/features_standard.joblib",
f"{version}/scalers/sequence_scaler.joblib",
f"{version}/scalers/features_minmax.joblib",
}
for model_name in model_names:
info = models.get(model_name, {})
rel_file = info.get("file") if isinstance(info, dict) else None
if rel_file:
patterns.add(f"{version}/{rel_file}")
return sorted(patterns)
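# Illustrative allow_patterns for version "v3" and model "xgboost", assuming
# models.json maps that key to "models/classical/xgboost.joblib":
#   ["v3/features/train_split.csv",
#    "v3/models.json",
#    "v3/models/classical/xgboost.joblib",
#    "v3/scalers/features_minmax.joblib",
#    "v3/scalers/features_standard.joblib",
#    "v3/scalers/sequence_scaler.joblib"]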
def download_models(version: str, model_names: list[str]) -> None:
"""Download selected model artifacts for one version."""
_ensure_hub()
from huggingface_hub import snapshot_download
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
allow = _collect_model_patterns(version, model_names)
print(f"[download_models] Downloading selected models for {version}: {model_names}")
snapshot_download(**_hf_kwargs(
allow_patterns=allow,
ignore_patterns=["*.log"],
))
write_datamap(version)
print(f"[download_models] Selected model artifacts ready for {version}")
def download_model(version: str, model_name: str) -> None:
"""Download one model artifact (plus required metadata/scalers)."""
download_models(version, [model_name])
def download_metrics_bundle(version: str) -> None:
"""Download files needed by metrics page for a specific version."""
_ensure_hub()
from huggingface_hub import snapshot_download
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
allow = [
f"{version}/models.json",
f"{version}/datamap.json",
f"{version}/results/**",
f"{version}/reports/**",
f"{version}/features/**",
f"{version}/figures/**",
]
print(f"[download_models] Downloading metrics bundle for {version}")
snapshot_download(**_hf_kwargs(
allow_patterns=allow,
ignore_patterns=["*.log"],
))
write_datamap(version)
print(f"[download_models] Metrics bundle ready for {version}")
def download_models_meta_only(versions: list[str]) -> None:
"""Download only models.json for listed versions."""
_ensure_hub()
from huggingface_hub import snapshot_download
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
allow = [f"{v}/models.json" for v in versions] + [f"{v}/datamap.json" for v in versions]
print(f"[download_models] Downloading metadata files: {allow}")
snapshot_download(**_hf_kwargs(
allow_patterns=allow,
ignore_patterns=["*.log"],
))
for v in versions:
write_datamap(v)
print("[download_models] Metadata ready")
def download_all() -> None:
"""Download all versions (v1 + v2 + v3) from HF Hub."""
_ensure_hub()
from huggingface_hub import snapshot_download
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
print(f"[download_models] Downloading all versions from {REPO_ID} -> {ARTIFACTS_DIR}")
snapshot_download(**_hf_kwargs(ignore_patterns=["*.log"]))
for v in ("v1", "v2", "v3"):
write_datamap(v)
SENTINEL.write_text("downloaded\n")
print("[download_models] Artifacts ready")
def ensure_metadata_first(versions: list[str] | None = None) -> None:
"""Guarantee models.json/datamap are present before registry usage."""
versions = versions or ["v1", "v2", "v3"]
download_models_meta_only(versions)
def main() -> None:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--version", default=None,
help="Download only this version, e.g. v1 or v2")
parser.add_argument("--model", default=None,
help="Download only this model key inside --version")
args = parser.parse_args()
if args.version and args.model:
download_models_meta_only([args.version])
download_model(args.version, args.model)
return
if args.version:
if version_loaded(args.version):
print(f"[download_models] {args.version} already present β€” skipping")
else:
download_version(args.version)
return
# Default: ensure models.json exists for all versions, but download only
# latest (v3) heavy model artifacts at startup.
download_models_meta_only(["v1", "v2", "v3"])
# Then fetch only top-N (default=3) v3 model artifacts by best score.
top_models = select_top_models("v3")
if not top_models:
print("[download_models] Could not resolve top models from models.json, falling back to full v3 download")
if already_downloaded("v3"):
print("[download_models] v3 artifacts already present β€” skipping download")
return
download_version("v3")
return
print(f"[download_models] Default startup model set (v3 top-{DEFAULT_STARTUP_TOP_MODELS}): {top_models}")
download_models("v3", top_models)
if __name__ == "__main__":
main()