Spaces:
Sleeping
Sleeping
Commit ·
7996ada
1
Parent(s): 04d655f
fix: dropdown and frontend
Browse files- api/main.py +50 -10
- api/model_registry.py +31 -1
- api/routers/simulate.py +23 -10
- api/routers/visualize.py +55 -0
- frontend/src/App.tsx +1 -1
- frontend/src/components/MetricsPanel.tsx +9 -1
- frontend/src/components/RecommendationPanel.tsx +75 -13
- frontend/src/components/ResearchPaper.tsx +23 -19
- frontend/src/components/VersionSelector.tsx +13 -95
- scripts/download_models.py +56 -1
api/main.py
CHANGED
|
@@ -53,7 +53,12 @@ from fastapi.responses import FileResponse
|
|
| 53 |
|
| 54 |
from api.model_registry import registry, registry_v1, registry_v2, registry_v3
|
| 55 |
from api.schemas import HealthResponse
|
| 56 |
-
from scripts.download_models import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
from src.utils.logger import get_logger
|
| 58 |
|
| 59 |
log = get_logger(__name__)
|
|
@@ -70,6 +75,11 @@ _FRONTEND_DIST = _HERE.parent / "frontend" / "dist"
|
|
| 70 |
async def lifespan(app: FastAPI):
|
| 71 |
"""Start API immediately; bootstrap top v3 models in background."""
|
| 72 |
log.info("Loading model registries …")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
_version_status["v3"] = "downloading"
|
| 74 |
app.state.v3_bootstrap_task = asyncio.create_task(_bg_bootstrap_v3())
|
| 75 |
log.info("v3 bootstrap started in background — API is available immediately")
|
|
@@ -147,6 +157,7 @@ async def list_versions():
|
|
| 147 |
out = []
|
| 148 |
for v in ["v3", "v2", "v1"]:
|
| 149 |
reg = _REGISTRIES[v]
|
|
|
|
| 150 |
on_disk = _version_loaded(v)
|
| 151 |
in_memory = reg.model_count > 0
|
| 152 |
meta = reg._version_meta # from models.json (loaded in __init__)
|
|
@@ -177,6 +188,7 @@ async def _bg_load_version(version: str) -> None:
|
|
| 177 |
)
|
| 178 |
await proc.wait()
|
| 179 |
if proc.returncode == 0:
|
|
|
|
| 180 |
_REGISTRIES[version].load_all()
|
| 181 |
_version_status[version] = "ready"
|
| 182 |
log.info("Version %s loaded on demand — %d models", version,
|
|
@@ -208,6 +220,7 @@ async def _bg_bootstrap_v3() -> None:
|
|
| 208 |
|
| 209 |
startup_models = select_top_models("v3")
|
| 210 |
if startup_models:
|
|
|
|
| 211 |
registry_v3.load_all(only_models=set(startup_models))
|
| 212 |
log.info(
|
| 213 |
"v3 registry ready — %d models loaded (startup top-%d set: %s)",
|
|
@@ -228,7 +241,7 @@ async def _bg_bootstrap_v3() -> None:
|
|
| 228 |
log.error("Failed during v3 bootstrap: %s", exc)
|
| 229 |
|
| 230 |
|
| 231 |
-
async def
|
| 232 |
import sys as _sys
|
| 233 |
|
| 234 |
key = _model_status_key(version, model_name)
|
|
@@ -244,15 +257,16 @@ async def _bg_load_model(version: str, model_name: str) -> None:
|
|
| 244 |
stderr=asyncio.subprocess.STDOUT,
|
| 245 |
)
|
| 246 |
await proc.wait()
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
| 250 |
else:
|
| 251 |
_model_status[key] = "error"
|
| 252 |
-
log.error("Failed to download
|
| 253 |
except Exception as exc:
|
| 254 |
_model_status[key] = "error"
|
| 255 |
-
log.error("Failed to
|
| 256 |
|
| 257 |
|
| 258 |
@app.post("/api/versions/{version}/load", tags=["meta"])
|
|
@@ -260,6 +274,8 @@ async def load_version(version: str, background_tasks: BackgroundTasks):
|
|
| 260 |
"""Download + activate a model version from HF Hub (runs in background)."""
|
| 261 |
if version not in _REGISTRIES:
|
| 262 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
|
|
|
|
|
|
| 263 |
if _version_status.get(version) == "downloading":
|
| 264 |
return {"status": "downloading", "version": version}
|
| 265 |
# If artifacts exist on disk but not loaded, just load without downloading
|
|
@@ -282,10 +298,30 @@ async def get_version_models_meta(version: str):
|
|
| 282 |
"""Return models.json metadata for a specific version."""
|
| 283 |
if version not in _REGISTRIES:
|
| 284 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
|
|
|
|
|
|
| 285 |
meta_path = _artifacts_dir() / version / "models.json"
|
| 286 |
if not meta_path.exists():
|
| 287 |
raise HTTPException(status_code=404, detail=f"models.json not found for {version}")
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
|
| 291 |
@app.get("/api/versions/{version}/models", tags=["meta"])
|
|
@@ -294,7 +330,9 @@ async def list_version_models(version: str):
|
|
| 294 |
if version not in _REGISTRIES:
|
| 295 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
| 296 |
|
|
|
|
| 297 |
reg = _REGISTRIES[version]
|
|
|
|
| 298 |
rows = []
|
| 299 |
for model_name, info in reg._catalog.items():
|
| 300 |
key = _model_status_key(version, model_name)
|
|
@@ -317,11 +355,13 @@ async def list_version_models(version: str):
|
|
| 317 |
|
| 318 |
@app.post("/api/versions/{version}/models/{model_name}/load", tags=["meta"])
|
| 319 |
async def load_single_model(version: str, model_name: str, background_tasks: BackgroundTasks):
|
| 320 |
-
"""
|
| 321 |
if version not in _REGISTRIES:
|
| 322 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
| 323 |
|
|
|
|
| 324 |
reg = _REGISTRIES[version]
|
|
|
|
| 325 |
if model_name not in reg._catalog:
|
| 326 |
raise HTTPException(status_code=404, detail=f"Unknown model '{model_name}' in {version}")
|
| 327 |
|
|
@@ -343,7 +383,7 @@ async def load_single_model(version: str, model_name: str, background_tasks: Bac
|
|
| 343 |
raise HTTPException(status_code=500, detail=f"Model '{model_name}' exists on disk but failed to load")
|
| 344 |
|
| 345 |
_model_status[key] = "downloading"
|
| 346 |
-
background_tasks.add_task(
|
| 347 |
return {"status": "downloading", "version": version, "model": model_name}
|
| 348 |
|
| 349 |
|
|
|
|
| 53 |
|
| 54 |
from api.model_registry import registry, registry_v1, registry_v2, registry_v3
|
| 55 |
from api.schemas import HealthResponse
|
| 56 |
+
from scripts.download_models import (
|
| 57 |
+
select_top_models,
|
| 58 |
+
DEFAULT_STARTUP_TOP_MODELS,
|
| 59 |
+
ensure_metadata_first,
|
| 60 |
+
write_datamap,
|
| 61 |
+
)
|
| 62 |
from src.utils.logger import get_logger
|
| 63 |
|
| 64 |
log = get_logger(__name__)
|
|
|
|
| 75 |
async def lifespan(app: FastAPI):
|
| 76 |
"""Start API immediately; bootstrap top v3 models in background."""
|
| 77 |
log.info("Loading model registries …")
|
| 78 |
+
# Hard requirement: metadata must exist first for all versions.
|
| 79 |
+
await asyncio.to_thread(ensure_metadata_first, ["v1", "v2", "v3"])
|
| 80 |
+
for reg in _REGISTRIES.values():
|
| 81 |
+
reg.refresh_metadata()
|
| 82 |
+
|
| 83 |
_version_status["v3"] = "downloading"
|
| 84 |
app.state.v3_bootstrap_task = asyncio.create_task(_bg_bootstrap_v3())
|
| 85 |
log.info("v3 bootstrap started in background — API is available immediately")
|
|
|
|
| 157 |
out = []
|
| 158 |
for v in ["v3", "v2", "v1"]:
|
| 159 |
reg = _REGISTRIES[v]
|
| 160 |
+
reg.ensure_metadata_loaded()
|
| 161 |
on_disk = _version_loaded(v)
|
| 162 |
in_memory = reg.model_count > 0
|
| 163 |
meta = reg._version_meta # from models.json (loaded in __init__)
|
|
|
|
| 188 |
)
|
| 189 |
await proc.wait()
|
| 190 |
if proc.returncode == 0:
|
| 191 |
+
_REGISTRIES[version].refresh_metadata()
|
| 192 |
_REGISTRIES[version].load_all()
|
| 193 |
_version_status[version] = "ready"
|
| 194 |
log.info("Version %s loaded on demand — %d models", version,
|
|
|
|
| 220 |
|
| 221 |
startup_models = select_top_models("v3")
|
| 222 |
if startup_models:
|
| 223 |
+
registry_v3.refresh_metadata()
|
| 224 |
registry_v3.load_all(only_models=set(startup_models))
|
| 225 |
log.info(
|
| 226 |
"v3 registry ready — %d models loaded (startup top-%d set: %s)",
|
|
|
|
| 241 |
log.error("Failed during v3 bootstrap: %s", exc)
|
| 242 |
|
| 243 |
|
| 244 |
+
async def _bg_download_model(version: str, model_name: str) -> None:
|
| 245 |
import sys as _sys
|
| 246 |
|
| 247 |
key = _model_status_key(version, model_name)
|
|
|
|
| 257 |
stderr=asyncio.subprocess.STDOUT,
|
| 258 |
)
|
| 259 |
await proc.wait()
|
| 260 |
+
_REGISTRIES[version].refresh_metadata()
|
| 261 |
+
if proc.returncode == 0 and _REGISTRIES[version].model_on_disk(model_name):
|
| 262 |
+
_model_status[key] = "on_disk"
|
| 263 |
+
log.info("Model %s/%s downloaded on demand", version, model_name)
|
| 264 |
else:
|
| 265 |
_model_status[key] = "error"
|
| 266 |
+
log.error("Failed to download model %s/%s", version, model_name)
|
| 267 |
except Exception as exc:
|
| 268 |
_model_status[key] = "error"
|
| 269 |
+
log.error("Failed to download model %s/%s: %s", version, model_name, exc)
|
| 270 |
|
| 271 |
|
| 272 |
@app.post("/api/versions/{version}/load", tags=["meta"])
|
|
|
|
| 274 |
"""Download + activate a model version from HF Hub (runs in background)."""
|
| 275 |
if version not in _REGISTRIES:
|
| 276 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
| 277 |
+
await asyncio.to_thread(ensure_metadata_first, [version])
|
| 278 |
+
_REGISTRIES[version].refresh_metadata()
|
| 279 |
if _version_status.get(version) == "downloading":
|
| 280 |
return {"status": "downloading", "version": version}
|
| 281 |
# If artifacts exist on disk but not loaded, just load without downloading
|
|
|
|
| 298 |
"""Return models.json metadata for a specific version."""
|
| 299 |
if version not in _REGISTRIES:
|
| 300 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
| 301 |
+
await asyncio.to_thread(ensure_metadata_first, [version])
|
| 302 |
+
_REGISTRIES[version].refresh_metadata()
|
| 303 |
meta_path = _artifacts_dir() / version / "models.json"
|
| 304 |
if not meta_path.exists():
|
| 305 |
raise HTTPException(status_code=404, detail=f"models.json not found for {version}")
|
| 306 |
+
datamap_path = _artifacts_dir() / version / "datamap.json"
|
| 307 |
+
if not datamap_path.exists():
|
| 308 |
+
await asyncio.to_thread(write_datamap, version)
|
| 309 |
+
return {
|
| 310 |
+
"models_meta": _json.loads(meta_path.read_text(encoding="utf-8")),
|
| 311 |
+
"datamap": _json.loads(datamap_path.read_text(encoding="utf-8")) if datamap_path.exists() else {},
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@app.get("/api/versions/{version}/datamap", tags=["meta"])
|
| 316 |
+
async def get_version_datamap(version: str):
|
| 317 |
+
"""Return datamap.json for a specific version; generate if missing."""
|
| 318 |
+
if version not in _REGISTRIES:
|
| 319 |
+
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
| 320 |
+
await asyncio.to_thread(ensure_metadata_first, [version])
|
| 321 |
+
datamap_path = _artifacts_dir() / version / "datamap.json"
|
| 322 |
+
if not datamap_path.exists():
|
| 323 |
+
await asyncio.to_thread(write_datamap, version)
|
| 324 |
+
return _json.loads(datamap_path.read_text(encoding="utf-8"))
|
| 325 |
|
| 326 |
|
| 327 |
@app.get("/api/versions/{version}/models", tags=["meta"])
|
|
|
|
| 330 |
if version not in _REGISTRIES:
|
| 331 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
| 332 |
|
| 333 |
+
await asyncio.to_thread(ensure_metadata_first, [version])
|
| 334 |
reg = _REGISTRIES[version]
|
| 335 |
+
reg.refresh_metadata()
|
| 336 |
rows = []
|
| 337 |
for model_name, info in reg._catalog.items():
|
| 338 |
key = _model_status_key(version, model_name)
|
|
|
|
| 355 |
|
| 356 |
@app.post("/api/versions/{version}/models/{model_name}/load", tags=["meta"])
|
| 357 |
async def load_single_model(version: str, model_name: str, background_tasks: BackgroundTasks):
|
| 358 |
+
"""Two-step per-model action: download first, then load into memory."""
|
| 359 |
if version not in _REGISTRIES:
|
| 360 |
raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
|
| 361 |
|
| 362 |
+
await asyncio.to_thread(ensure_metadata_first, [version])
|
| 363 |
reg = _REGISTRIES[version]
|
| 364 |
+
reg.refresh_metadata()
|
| 365 |
if model_name not in reg._catalog:
|
| 366 |
raise HTTPException(status_code=404, detail=f"Unknown model '{model_name}' in {version}")
|
| 367 |
|
|
|
|
| 383 |
raise HTTPException(status_code=500, detail=f"Model '{model_name}' exists on disk but failed to load")
|
| 384 |
|
| 385 |
_model_status[key] = "downloading"
|
| 386 |
+
background_tasks.add_task(_bg_download_model, version, model_name)
|
| 387 |
return {"status": "downloading", "version": version, "model": model_name}
|
| 388 |
|
| 389 |
|
api/model_registry.py
CHANGED
|
@@ -99,7 +99,7 @@ def _load_version_meta(version: str) -> dict[str, Any]:
|
|
| 99 |
"""
|
| 100 |
json_path = _ARTIFACTS / version / "models.json"
|
| 101 |
if not json_path.exists():
|
| 102 |
-
log.
|
| 103 |
return {}
|
| 104 |
try:
|
| 105 |
with open(json_path, encoding="utf-8") as fh:
|
|
@@ -197,6 +197,34 @@ class ModelRegistry:
|
|
| 197 |
return False
|
| 198 |
return (self._artifacts / rel).exists()
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
def load_model(self, model_name: str) -> bool:
|
| 201 |
"""Load one model (and ensemble dependencies when applicable).
|
| 202 |
|
|
@@ -222,6 +250,7 @@ class ModelRegistry:
|
|
| 222 |
if only_models is None and self.models:
|
| 223 |
log.debug("Registry already populated — skipping load_all()")
|
| 224 |
return
|
|
|
|
| 225 |
self._detect_device()
|
| 226 |
self._load_scaler()
|
| 227 |
self.feature_cols = self._load_feature_cols()
|
|
@@ -939,6 +968,7 @@ class ModelRegistry:
|
|
| 939 |
|
| 940 |
def list_models(self) -> list[dict[str, Any]]:
|
| 941 |
"""Return full model listing with versioning, metrics, and load status."""
|
|
|
|
| 942 |
all_metrics = self.get_metrics()
|
| 943 |
# Registry version prefix: "v1" -> "1", "v2" -> "2", "v3" -> "3"
|
| 944 |
reg_major = self.version.lstrip("v")
|
|
|
|
| 99 |
"""
|
| 100 |
json_path = _ARTIFACTS / version / "models.json"
|
| 101 |
if not json_path.exists():
|
| 102 |
+
log.info("models.json not found for %s yet — catalog will refresh after metadata bootstrap", version)
|
| 103 |
return {}
|
| 104 |
try:
|
| 105 |
with open(json_path, encoding="utf-8") as fh:
|
|
|
|
| 197 |
return False
|
| 198 |
return (self._artifacts / rel).exists()
|
| 199 |
|
| 200 |
+
def refresh_metadata(self) -> None:
|
| 201 |
+
"""Reload models.json metadata and refresh catalog-derived settings."""
|
| 202 |
+
self._version_meta = _load_version_meta(self.version)
|
| 203 |
+
self._catalog = self._version_meta.get("models", {})
|
| 204 |
+
|
| 205 |
+
json_features = self._version_meta.get("feature_set")
|
| 206 |
+
if json_features:
|
| 207 |
+
self.feature_cols = list(json_features)
|
| 208 |
+
elif self.version == "v3":
|
| 209 |
+
self.feature_cols = list(V3_FEATURE_COLS)
|
| 210 |
+
else:
|
| 211 |
+
self.feature_cols = list(FEATURE_COLS_SCALAR)
|
| 212 |
+
|
| 213 |
+
ensemble_info = self._catalog.get("best_ensemble", {})
|
| 214 |
+
self._ensemble_components = ensemble_info.get("components", [])
|
| 215 |
+
self._ensemble_weights = {}
|
| 216 |
+
for cname in self._ensemble_components:
|
| 217 |
+
cinfo = self._catalog.get(cname, {})
|
| 218 |
+
self._ensemble_weights[cname] = cinfo.get("r2", 1.0)
|
| 219 |
+
|
| 220 |
+
def ensure_metadata_loaded(self) -> None:
|
| 221 |
+
"""Refresh metadata lazily if catalog is empty and models.json exists now."""
|
| 222 |
+
if self._catalog:
|
| 223 |
+
return
|
| 224 |
+
json_path = _ARTIFACTS / self.version / "models.json"
|
| 225 |
+
if json_path.exists():
|
| 226 |
+
self.refresh_metadata()
|
| 227 |
+
|
| 228 |
def load_model(self, model_name: str) -> bool:
|
| 229 |
"""Load one model (and ensemble dependencies when applicable).
|
| 230 |
|
|
|
|
| 250 |
if only_models is None and self.models:
|
| 251 |
log.debug("Registry already populated — skipping load_all()")
|
| 252 |
return
|
| 253 |
+
self.ensure_metadata_loaded()
|
| 254 |
self._detect_device()
|
| 255 |
self._load_scaler()
|
| 256 |
self.feature_cols = self._load_feature_cols()
|
|
|
|
| 968 |
|
| 969 |
def list_models(self) -> list[dict[str, Any]]:
|
| 970 |
"""Return full model listing with versioning, metrics, and load status."""
|
| 971 |
+
self.ensure_metadata_loaded()
|
| 972 |
all_metrics = self.get_metrics()
|
| 973 |
# Registry version prefix: "v1" -> "1", "v2" -> "2", "v3" -> "3"
|
| 974 |
reg_major = self.version.lstrip("v")
|
api/routers/simulate.py
CHANGED
|
@@ -315,17 +315,30 @@ async def simulate_batteries(req: SimulateRequest):
|
|
| 315 |
eol_thr = req.eol_threshold
|
| 316 |
N = req.steps
|
| 317 |
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
-
# Deep sequence models need per-sample tensors and are not used in this endpoint.
|
| 321 |
-
# Classical + ensemble models use batch predict_array().
|
| 322 |
family = registry_v2.model_meta.get(model_name, {}).get("family", "classical")
|
| 323 |
is_deep = family in ("deep_pytorch", "deep_keras")
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
# Determine scaler note for logging (mirrors training decision exactly)
|
| 331 |
if registry_v2.model_meta.get(model_name, {}).get("requires_scaling", False):
|
|
@@ -337,8 +350,8 @@ async def simulate_batteries(req: SimulateRequest):
|
|
| 337 |
|
| 338 |
effective_model = "linear_fallback"
|
| 339 |
log.info(
|
| 340 |
-
"simulate: %d batteries x %d steps |
|
| 341 |
-
len(req.batteries), N, model_name, ml_batchable, scaler_note, time_unit,
|
| 342 |
)
|
| 343 |
|
| 344 |
results: list[BatterySimResult] = []
|
|
|
|
| 315 |
eol_thr = req.eol_threshold
|
| 316 |
N = req.steps
|
| 317 |
|
| 318 |
+
requested_model = req.model_name or registry_v2.default_model or "best_ensemble"
|
| 319 |
+
|
| 320 |
+
# Resolve to a batchable loaded model once (to avoid per-battery fallback spam).
|
| 321 |
+
# Priority: requested -> registry default -> first loaded classical model.
|
| 322 |
+
model_name = requested_model
|
| 323 |
+
if requested_model == "best_ensemble":
|
| 324 |
+
ensemble_components = registry_v2.model_meta.get("best_ensemble", {}).get("components", [])
|
| 325 |
+
if not ensemble_components:
|
| 326 |
+
model_name = registry_v2.default_model or ""
|
| 327 |
|
|
|
|
|
|
|
| 328 |
family = registry_v2.model_meta.get(model_name, {}).get("family", "classical")
|
| 329 |
is_deep = family in ("deep_pytorch", "deep_keras")
|
| 330 |
+
|
| 331 |
+
if (model_name not in registry_v2.models) or is_deep:
|
| 332 |
+
fallback_loaded = [
|
| 333 |
+
name for name, meta in registry_v2.model_meta.items()
|
| 334 |
+
if name in registry_v2.models and meta.get("family") == "classical"
|
| 335 |
+
]
|
| 336 |
+
if fallback_loaded:
|
| 337 |
+
model_name = fallback_loaded[0]
|
| 338 |
+
family = registry_v2.model_meta.get(model_name, {}).get("family", "classical")
|
| 339 |
+
is_deep = family in ("deep_pytorch", "deep_keras")
|
| 340 |
+
|
| 341 |
+
ml_batchable = req.use_ml and not is_deep and (model_name == "best_ensemble" or model_name in registry_v2.models)
|
| 342 |
|
| 343 |
# Determine scaler note for logging (mirrors training decision exactly)
|
| 344 |
if registry_v2.model_meta.get(model_name, {}).get("requires_scaling", False):
|
|
|
|
| 350 |
|
| 351 |
effective_model = "linear_fallback"
|
| 352 |
log.info(
|
| 353 |
+
"simulate: %d batteries x %d steps | requested=%s | effective=%s | batchable=%s | scaler=%s | unit=%s",
|
| 354 |
+
len(req.batteries), N, requested_model, model_name, ml_batchable, scaler_note, time_unit,
|
| 355 |
)
|
| 356 |
|
| 357 |
results: list[BatterySimResult] = []
|
api/routers/visualize.py
CHANGED
|
@@ -17,6 +17,7 @@ from fastapi.responses import FileResponse
|
|
| 17 |
|
| 18 |
from api.model_registry import registry, classify_degradation, soh_to_color
|
| 19 |
from api.schemas import BatteryVizData, DashboardData
|
|
|
|
| 20 |
|
| 21 |
router = APIRouter(prefix="/api", tags=["visualization"])
|
| 22 |
|
|
@@ -223,9 +224,24 @@ def _battery_stats_for_version(version: str) -> dict:
|
|
| 223 |
def _build_metrics_payload(version: str) -> dict:
|
| 224 |
_ensure_version(version)
|
| 225 |
root = _version_root(version)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
results = root / "results"
|
| 227 |
reports = root / "reports"
|
| 228 |
models_meta = _safe_read_json(root / "models.json")
|
|
|
|
| 229 |
|
| 230 |
unified = _safe_read_csv_first([results / "unified_results.csv"])
|
| 231 |
classical_results = _safe_read_csv_first([
|
|
@@ -258,9 +274,48 @@ def _build_metrics_payload(version: str) -> dict:
|
|
| 258 |
vae_lstm = _safe_read_json_first([results / "vae_lstm_results.json"])
|
| 259 |
dg_itransformer = _safe_read_json_first([results / "dg_itransformer_results.json"])
|
| 260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
return {
|
| 262 |
"version": version,
|
| 263 |
"models_meta": models_meta,
|
|
|
|
| 264 |
"unified_results": unified,
|
| 265 |
"classical_results": classical_results,
|
| 266 |
"classical_soh": classical_soh,
|
|
|
|
| 17 |
|
| 18 |
from api.model_registry import registry, classify_degradation, soh_to_color
|
| 19 |
from api.schemas import BatteryVizData, DashboardData
|
| 20 |
+
from scripts.download_models import ensure_metadata_first, download_metrics_bundle
|
| 21 |
|
| 22 |
router = APIRouter(prefix="/api", tags=["visualization"])
|
| 23 |
|
|
|
|
| 224 |
def _build_metrics_payload(version: str) -> dict:
|
| 225 |
_ensure_version(version)
|
| 226 |
root = _version_root(version)
|
| 227 |
+
|
| 228 |
+
# Ensure artifacts required by metrics exist locally for this version.
|
| 229 |
+
try:
|
| 230 |
+
ensure_metadata_first([version])
|
| 231 |
+
results_dir = root / "results"
|
| 232 |
+
figures_dir = root / "figures"
|
| 233 |
+
has_results = results_dir.exists() and any(results_dir.glob("*"))
|
| 234 |
+
has_figures = figures_dir.exists() and any(figures_dir.glob("*"))
|
| 235 |
+
if not has_results and not has_figures:
|
| 236 |
+
download_metrics_bundle(version)
|
| 237 |
+
except Exception:
|
| 238 |
+
# Keep endpoint resilient; payload will still be built from whatever exists.
|
| 239 |
+
pass
|
| 240 |
+
|
| 241 |
results = root / "results"
|
| 242 |
reports = root / "reports"
|
| 243 |
models_meta = _safe_read_json(root / "models.json")
|
| 244 |
+
datamap = _safe_read_json(root / "datamap.json")
|
| 245 |
|
| 246 |
unified = _safe_read_csv_first([results / "unified_results.csv"])
|
| 247 |
classical_results = _safe_read_csv_first([
|
|
|
|
| 274 |
vae_lstm = _safe_read_json_first([results / "vae_lstm_results.json"])
|
| 275 |
dg_itransformer = _safe_read_json_first([results / "dg_itransformer_results.json"])
|
| 276 |
|
| 277 |
+
# Fallback: build unified/classical-like rows directly from models.json when
|
| 278 |
+
# result CSVs are not yet downloaded for a version.
|
| 279 |
+
if not unified and isinstance(models_meta, dict):
|
| 280 |
+
model_rows = []
|
| 281 |
+
for name, info in (models_meta.get("models") or {}).items():
|
| 282 |
+
if not isinstance(info, dict):
|
| 283 |
+
continue
|
| 284 |
+
model_rows.append({
|
| 285 |
+
"model": name,
|
| 286 |
+
"family": info.get("family"),
|
| 287 |
+
"R2": info.get("r2"),
|
| 288 |
+
"MAE": info.get("mae"),
|
| 289 |
+
"RMSE": info.get("rmse"),
|
| 290 |
+
"MAPE": info.get("mape"),
|
| 291 |
+
"within_5pct": info.get("within_5pct"),
|
| 292 |
+
"f1_macro": info.get("f1_macro"),
|
| 293 |
+
"f1_weighted": info.get("f1_weighted"),
|
| 294 |
+
})
|
| 295 |
+
unified = model_rows
|
| 296 |
+
if not classical_results:
|
| 297 |
+
classical_results = [r for r in model_rows if (r.get("family") or "").startswith("classical")]
|
| 298 |
+
|
| 299 |
+
# Fallback summaries derived from unified rows
|
| 300 |
+
if not training_summary and unified:
|
| 301 |
+
valid_r2 = [r.get("R2") for r in unified if isinstance(r.get("R2"), (int, float))]
|
| 302 |
+
valid_w5 = [r.get("within_5pct") for r in unified if isinstance(r.get("within_5pct"), (int, float))]
|
| 303 |
+
best = max(unified, key=lambda r: r.get("R2") if isinstance(r.get("R2"), (int, float)) else -999)
|
| 304 |
+
training_summary = {
|
| 305 |
+
"best_model": best.get("model"),
|
| 306 |
+
"best_r2": best.get("R2"),
|
| 307 |
+
"best_within_5pct": best.get("within_5pct"),
|
| 308 |
+
"total_models": len(unified),
|
| 309 |
+
"mean_within_5pct": (sum(valid_w5) / len(valid_w5)) if valid_w5 else None,
|
| 310 |
+
"passed_models": sum(1 for v in valid_w5 if v >= 95.0),
|
| 311 |
+
"pass_rate_pct": (sum(1 for v in valid_w5 if v >= 95.0) / len(valid_w5) * 100.0) if valid_w5 else 0.0,
|
| 312 |
+
"mean_r2": (sum(valid_r2) / len(valid_r2)) if valid_r2 else None,
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
return {
|
| 316 |
"version": version,
|
| 317 |
"models_meta": models_meta,
|
| 318 |
+
"datamap": datamap,
|
| 319 |
"unified_results": unified,
|
| 320 |
"classical_results": classical_results,
|
| 321 |
"classical_soh": classical_soh,
|
frontend/src/App.tsx
CHANGED
|
@@ -82,7 +82,7 @@ export default function App() {
|
|
| 82 |
{activeTab === "predict" && <PredictionForm apiVersion={apiVersion} />}
|
| 83 |
{activeTab === "graphs" && <GraphPanel />}
|
| 84 |
{activeTab === "recommend" && <RecommendationPanel apiVersion={apiVersion} />}
|
| 85 |
-
{activeTab === "metrics" && <MetricsPanel />}
|
| 86 |
{activeTab === "paper" && <ResearchPaper />}
|
| 87 |
</main>
|
| 88 |
|
|
|
|
| 82 |
{activeTab === "predict" && <PredictionForm apiVersion={apiVersion} />}
|
| 83 |
{activeTab === "graphs" && <GraphPanel />}
|
| 84 |
{activeTab === "recommend" && <RecommendationPanel apiVersion={apiVersion} />}
|
| 85 |
+
{activeTab === "metrics" && <MetricsPanel apiVersion={apiVersion} />}
|
| 86 |
{activeTab === "paper" && <ResearchPaper />}
|
| 87 |
</main>
|
| 88 |
|
frontend/src/components/MetricsPanel.tsx
CHANGED
|
@@ -99,7 +99,11 @@ function StatCard({ label, value, color = "text-green-400", subtitle, icon, tren
|
|
| 99 |
);
|
| 100 |
}
|
| 101 |
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
const [data, setData] = useState<MetricsData | null>(null);
|
| 104 |
const [loading, setLoading] = useState(true);
|
| 105 |
const [error, setError] = useState<string | null>(null);
|
|
@@ -114,6 +118,10 @@ export default function MetricsPanel() {
|
|
| 114 |
const [compareSelected, setCompareSelected] = useState<string[]>([]);
|
| 115 |
const [chartView, setChartView] = useState<"bar" | "radar" | "scatter">("bar");
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
useEffect(() => {
|
| 118 |
setLoading(true);
|
| 119 |
setError(null);
|
|
|
|
| 99 |
);
|
| 100 |
}
|
| 101 |
|
| 102 |
+
type Props = {
|
| 103 |
+
apiVersion?: "v1" | "v2" | "v3";
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
export default function MetricsPanel({ apiVersion = "v3" }: Props) {
|
| 107 |
const [data, setData] = useState<MetricsData | null>(null);
|
| 108 |
const [loading, setLoading] = useState(true);
|
| 109 |
const [error, setError] = useState<string | null>(null);
|
|
|
|
| 118 |
const [compareSelected, setCompareSelected] = useState<string[]>([]);
|
| 119 |
const [chartView, setChartView] = useState<"bar" | "radar" | "scatter">("bar");
|
| 120 |
|
| 121 |
+
useEffect(() => {
|
| 122 |
+
setActiveVersion(apiVersion);
|
| 123 |
+
}, [apiVersion]);
|
| 124 |
+
|
| 125 |
useEffect(() => {
|
| 126 |
setLoading(true);
|
| 127 |
setError(null);
|
frontend/src/components/RecommendationPanel.tsx
CHANGED
|
@@ -1,4 +1,7 @@
|
|
| 1 |
import { useEffect, useState, useMemo } from "react";
|
|
|
|
|
|
|
|
|
|
| 2 |
import {
|
| 3 |
BarChart, Bar, XAxis, YAxis, Tooltip, Legend, ResponsiveContainer,
|
| 4 |
CartesianGrid, RadarChart, Radar, PolarGrid, PolarAngleAxis, PolarRadiusAxis,
|
|
@@ -50,6 +53,40 @@ type Props = {
|
|
| 50 |
apiVersion: "v1" | "v2" | "v3";
|
| 51 |
};
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
export default function RecommendationPanel({ apiVersion }: Props) {
|
| 54 |
const [batteryId, setBatteryId] = useState("B0005");
|
| 55 |
const [currentCycle, setCurrentCycle] = useState(100);
|
|
@@ -62,6 +99,7 @@ export default function RecommendationPanel({ apiVersion }: Props) {
|
|
| 62 |
const [loading, setLoading] = useState(false);
|
| 63 |
const [error, setError] = useState<string | null>(null);
|
| 64 |
const [expandedRow, setExpandedRow] = useState<number | null>(null);
|
|
|
|
| 65 |
const [chartTab, setChartTab] = useState<"rul" | "params" | "radar">("rul");
|
| 66 |
|
| 67 |
// Fetch available loaded models for selector
|
|
@@ -84,6 +122,7 @@ export default function RecommendationPanel({ apiVersion }: Props) {
|
|
| 84 |
...(selectedModel ? { model_name: selectedModel } : {}),
|
| 85 |
});
|
| 86 |
setResult(res);
|
|
|
|
| 87 |
} catch (e: any) {
|
| 88 |
setError(e.response?.data?.detail || e.message);
|
| 89 |
} finally {
|
|
@@ -127,6 +166,8 @@ export default function RecommendationPanel({ apiVersion }: Props) {
|
|
| 127 |
|
| 128 |
const baseline = result?.recommendations[0];
|
| 129 |
const bestOnly = result?.recommendations.find((r) => r.rank === 1);
|
|
|
|
|
|
|
| 130 |
|
| 131 |
return (
|
| 132 |
<div className="space-y-5">
|
|
@@ -327,18 +368,19 @@ export default function RecommendationPanel({ apiVersion }: Props) {
|
|
| 327 |
)}
|
| 328 |
</div>
|
| 329 |
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
<div className="
|
| 333 |
-
<div className="flex items-center
|
| 334 |
-
<
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
|
|
|
| 338 |
</div>
|
| 339 |
-
<
|
| 340 |
-
</div>
|
| 341 |
-
<div className="overflow-x-auto">
|
| 342 |
<table className="w-full text-sm">
|
| 343 |
<thead>
|
| 344 |
<tr className="text-gray-500 border-b border-gray-800 bg-gray-950/50">
|
|
@@ -361,9 +403,12 @@ export default function RecommendationPanel({ apiVersion }: Props) {
|
|
| 361 |
<tr
|
| 362 |
key={rec.rank}
|
| 363 |
className={`border-b border-gray-800/40 hover:bg-gray-800/40 transition-colors cursor-pointer ${
|
| 364 |
-
rec.rank === 1 ? "bg-yellow-900/10" : ""
|
| 365 |
}`}
|
| 366 |
-
onClick={() =>
|
|
|
|
|
|
|
|
|
|
| 367 |
>
|
| 368 |
<td className="py-2.5 px-3">
|
| 369 |
<span className="flex items-center"><RankIcon rank={rec.rank} /></span>
|
|
@@ -421,6 +466,23 @@ export default function RecommendationPanel({ apiVersion }: Props) {
|
|
| 421 |
})}
|
| 422 |
</tbody>
|
| 423 |
</table>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
</div>
|
| 425 |
</div>
|
| 426 |
</div>
|
|
|
|
| 1 |
import { useEffect, useState, useMemo } from "react";
|
| 2 |
+
import { Canvas } from "@react-three/fiber";
|
| 3 |
+
import { OrbitControls, Html } from "@react-three/drei";
|
| 4 |
+
import * as THREE from "three";
|
| 5 |
import {
|
| 6 |
BarChart, Bar, XAxis, YAxis, Tooltip, Legend, ResponsiveContainer,
|
| 7 |
CartesianGrid, RadarChart, Radar, PolarGrid, PolarAngleAxis, PolarRadiusAxis,
|
|
|
|
| 53 |
apiVersion: "v1" | "v2" | "v3";
|
| 54 |
};
|
| 55 |
|
| 56 |
+
function Battery3D({ currentSoh, projectedSoh }: { currentSoh: number; projectedSoh: number }) {
|
| 57 |
+
const clamped = Math.max(0, Math.min(100, projectedSoh));
|
| 58 |
+
const fillHeight = 0.2 + (clamped / 100) * 2.6;
|
| 59 |
+
const fillY = -1.4 + fillHeight / 2;
|
| 60 |
+
const color = clamped >= 90 ? "#22c55e" : clamped >= 80 ? "#eab308" : clamped >= 70 ? "#f97316" : "#ef4444";
|
| 61 |
+
|
| 62 |
+
return (
|
| 63 |
+
<Canvas camera={{ position: [3.2, 2.2, 3.2], fov: 45 }}>
|
| 64 |
+
<ambientLight intensity={0.7} />
|
| 65 |
+
<directionalLight position={[3, 5, 2]} intensity={1.1} />
|
| 66 |
+
<group>
|
| 67 |
+
<mesh>
|
| 68 |
+
<cylinderGeometry args={[0.9, 0.9, 3.1, 48]} />
|
| 69 |
+
<meshPhysicalMaterial color="#94a3b8" transparent opacity={0.16} transmission={0.8} roughness={0.05} />
|
| 70 |
+
</mesh>
|
| 71 |
+
<mesh position={[0, fillY, 0]}>
|
| 72 |
+
<cylinderGeometry args={[0.76, 0.76, fillHeight, 48]} />
|
| 73 |
+
<meshStandardMaterial color={color} emissive={color} emissiveIntensity={0.3} roughness={0.35} />
|
| 74 |
+
</mesh>
|
| 75 |
+
<mesh position={[0, 1.75, 0]}>
|
| 76 |
+
<cylinderGeometry args={[0.24, 0.24, 0.22, 24]} />
|
| 77 |
+
<meshStandardMaterial color="#e5e7eb" metalness={0.9} roughness={0.1} />
|
| 78 |
+
</mesh>
|
| 79 |
+
<Html position={[0, -2.15, 0]} center>
|
| 80 |
+
<div className="text-xs text-gray-300 bg-gray-900/80 border border-gray-700 rounded px-2 py-1">
|
| 81 |
+
SOH: <span className="font-semibold text-white">{currentSoh.toFixed(1)}%</span> → <span className="font-semibold" style={{ color }}>{clamped.toFixed(1)}%</span>
|
| 82 |
+
</div>
|
| 83 |
+
</Html>
|
| 84 |
+
</group>
|
| 85 |
+
<OrbitControls enablePan={false} maxDistance={6} minDistance={2.4} />
|
| 86 |
+
</Canvas>
|
| 87 |
+
);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
export default function RecommendationPanel({ apiVersion }: Props) {
|
| 91 |
const [batteryId, setBatteryId] = useState("B0005");
|
| 92 |
const [currentCycle, setCurrentCycle] = useState(100);
|
|
|
|
| 99 |
const [loading, setLoading] = useState(false);
|
| 100 |
const [error, setError] = useState<string | null>(null);
|
| 101 |
const [expandedRow, setExpandedRow] = useState<number | null>(null);
|
| 102 |
+
const [selectedRank, setSelectedRank] = useState<number>(1);
|
| 103 |
const [chartTab, setChartTab] = useState<"rul" | "params" | "radar">("rul");
|
| 104 |
|
| 105 |
// Fetch available loaded models for selector
|
|
|
|
| 122 |
...(selectedModel ? { model_name: selectedModel } : {}),
|
| 123 |
});
|
| 124 |
setResult(res);
|
| 125 |
+
setSelectedRank(1);
|
| 126 |
} catch (e: any) {
|
| 127 |
setError(e.response?.data?.detail || e.message);
|
| 128 |
} finally {
|
|
|
|
| 166 |
|
| 167 |
const baseline = result?.recommendations[0];
|
| 168 |
const bestOnly = result?.recommendations.find((r) => r.rank === 1);
|
| 169 |
+
const selectedRec = result?.recommendations.find((r) => r.rank === selectedRank) ?? bestOnly;
|
| 170 |
+
const projectedSoh = Math.min(100, currentSoh + ((selectedRec?.rul_improvement_pct ?? 0) / 6));
|
| 171 |
|
| 172 |
return (
|
| 173 |
<div className="space-y-5">
|
|
|
|
| 368 |
)}
|
| 369 |
</div>
|
| 370 |
|
| 371 |
+
<div className="grid grid-cols-1 xl:grid-cols-2 gap-5">
|
| 372 |
+
{/* Recommendations table */}
|
| 373 |
+
<div className="bg-gray-900 rounded-xl border border-gray-800 overflow-hidden">
|
| 374 |
+
<div className="p-4 border-b border-gray-800 flex items-center justify-between">
|
| 375 |
+
<div className="flex items-center gap-2">
|
| 376 |
+
<Trophy className="w-4 h-4 text-yellow-400" />
|
| 377 |
+
<span className="text-sm font-semibold text-gray-300">
|
| 378 |
+
Recommendations for {result.battery_id} — SOH: {result.current_soh}%
|
| 379 |
+
</span>
|
| 380 |
+
</div>
|
| 381 |
+
<span className="text-xs text-gray-500">{result.recommendations.length} configs</span>
|
| 382 |
</div>
|
| 383 |
+
<div className="overflow-x-auto">
|
|
|
|
|
|
|
| 384 |
<table className="w-full text-sm">
|
| 385 |
<thead>
|
| 386 |
<tr className="text-gray-500 border-b border-gray-800 bg-gray-950/50">
|
|
|
|
| 403 |
<tr
|
| 404 |
key={rec.rank}
|
| 405 |
className={`border-b border-gray-800/40 hover:bg-gray-800/40 transition-colors cursor-pointer ${
|
| 406 |
+
rec.rank === selectedRank ? "bg-emerald-900/20" : rec.rank === 1 ? "bg-yellow-900/10" : ""
|
| 407 |
}`}
|
| 408 |
+
onClick={() => {
|
| 409 |
+
setExpandedRow(expanded ? null : rec.rank);
|
| 410 |
+
setSelectedRank(rec.rank);
|
| 411 |
+
}}
|
| 412 |
>
|
| 413 |
<td className="py-2.5 px-3">
|
| 414 |
<span className="flex items-center"><RankIcon rank={rec.rank} /></span>
|
|
|
|
| 466 |
})}
|
| 467 |
</tbody>
|
| 468 |
</table>
|
| 469 |
+
</div>
|
| 470 |
+
</div>
|
| 471 |
+
|
| 472 |
+
{/* Interactive 3D battery panel */}
|
| 473 |
+
<div className="bg-gray-900 rounded-xl border border-gray-800 overflow-hidden">
|
| 474 |
+
<div className="p-4 border-b border-gray-800 flex items-center justify-between">
|
| 475 |
+
<span className="text-sm font-semibold text-gray-300">Interactive Battery Impact</span>
|
| 476 |
+
<span className="text-xs text-emerald-400">Selected: #{selectedRec?.rank ?? 1}</span>
|
| 477 |
+
</div>
|
| 478 |
+
<div className="h-105">
|
| 479 |
+
<Battery3D currentSoh={currentSoh} projectedSoh={projectedSoh} />
|
| 480 |
+
</div>
|
| 481 |
+
<div className="p-4 border-t border-gray-800 text-xs text-gray-400 space-y-1">
|
| 482 |
+
<div>RUL: <span className="text-white font-semibold">{selectedRec?.predicted_rul.toFixed(0) ?? "-"} cycles</span></div>
|
| 483 |
+
<div>Gain: <span className="text-emerald-400 font-semibold">{selectedRec && selectedRec.rul_improvement > 0 ? "+" : ""}{selectedRec?.rul_improvement.toFixed(0) ?? "-"} cycles</span></div>
|
| 484 |
+
<div>Conditions: <span className="text-white">{selectedRec?.ambient_temperature}°C, {selectedRec?.discharge_current}A, {selectedRec?.cutoff_voltage}V</span></div>
|
| 485 |
+
</div>
|
| 486 |
</div>
|
| 487 |
</div>
|
| 488 |
</div>
|
frontend/src/components/ResearchPaper.tsx
CHANGED
|
@@ -32,16 +32,13 @@ const mdComponents: React.ComponentProps<typeof ReactMarkdown>["components"] = {
|
|
| 32 |
|
| 33 |
// ── Lists ─────────────────────────────────────────────────────────────
|
| 34 |
ul: ({ children }) => (
|
| 35 |
-
<ul className="space-y-1.5 mb-4
|
| 36 |
),
|
| 37 |
ol: ({ children }) => (
|
| 38 |
<ol className="list-decimal list-inside space-y-1.5 mb-4 text-gray-300 text-sm">{children}</ol>
|
| 39 |
),
|
| 40 |
li: ({ children }) => (
|
| 41 |
-
<li className="
|
| 42 |
-
<span className="text-green-400 mt-1 shrink-0">•</span>
|
| 43 |
-
<span>{children}</span>
|
| 44 |
-
</li>
|
| 45 |
),
|
| 46 |
|
| 47 |
// ── Code ─────────────────────────────────────────────────────────────
|
|
@@ -189,20 +186,27 @@ export default function ResearchPaper() {
|
|
| 189 |
const [usedFallback, setUsedFallback] = useState(false);
|
| 190 |
|
| 191 |
useEffect(() => {
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
}, []);
|
| 207 |
|
| 208 |
return (
|
|
|
|
| 32 |
|
| 33 |
// ── Lists ─────────────────────────────────────────────────────────────
|
| 34 |
ul: ({ children }) => (
|
| 35 |
+
<ul className="list-disc list-inside space-y-1.5 mb-4 text-gray-300 text-sm">{children}</ul>
|
| 36 |
),
|
| 37 |
ol: ({ children }) => (
|
| 38 |
<ol className="list-decimal list-inside space-y-1.5 mb-4 text-gray-300 text-sm">{children}</ol>
|
| 39 |
),
|
| 40 |
li: ({ children }) => (
|
| 41 |
+
<li className="text-gray-300 text-sm">{children}</li>
|
|
|
|
|
|
|
|
|
|
| 42 |
),
|
| 43 |
|
| 44 |
// ── Code ─────────────────────────────────────────────────────────────
|
|
|
|
| 186 |
const [usedFallback, setUsedFallback] = useState(false);
|
| 187 |
|
| 188 |
useEffect(() => {
|
| 189 |
+
const sources = ["/research_paper.md", "/docs/research_paper.md"];
|
| 190 |
+
const tryLoad = async () => {
|
| 191 |
+
for (const src of sources) {
|
| 192 |
+
try {
|
| 193 |
+
const r = await fetch(src);
|
| 194 |
+
if (!r.ok) continue;
|
| 195 |
+
const md = await r.text();
|
| 196 |
+
if (md && md.trim().length > 0) {
|
| 197 |
+
setMarkdown(md);
|
| 198 |
+
setLoading(false);
|
| 199 |
+
return;
|
| 200 |
+
}
|
| 201 |
+
} catch {
|
| 202 |
+
// Try next source
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
setMarkdown(FALLBACK);
|
| 206 |
+
setUsedFallback(true);
|
| 207 |
+
setLoading(false);
|
| 208 |
+
};
|
| 209 |
+
tryLoad();
|
| 210 |
}, []);
|
| 211 |
|
| 212 |
return (
|
frontend/src/components/VersionSelector.tsx
CHANGED
|
@@ -11,11 +11,10 @@
|
|
| 11 |
|
| 12 |
import { useCallback, useEffect, useRef, useState } from "react";
|
| 13 |
import {
|
| 14 |
-
ChevronRight,
|
| 15 |
} from "lucide-react";
|
| 16 |
import {
|
| 17 |
fetchVersions,
|
| 18 |
-
loadVersion,
|
| 19 |
VersionInfo,
|
| 20 |
fetchVersionModels,
|
| 21 |
loadVersionModel,
|
|
@@ -30,7 +29,6 @@ interface Props {
|
|
| 30 |
export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
| 31 |
const [open, setOpen] = useState(false);
|
| 32 |
const [versions, setVersions] = useState<VersionInfo[]>([]);
|
| 33 |
-
const [busy, setBusy] = useState<string | null>(null);
|
| 34 |
const [modelBusy, setModelBusy] = useState<string | null>(null);
|
| 35 |
const [expandedVersion, setExpandedVersion] = useState<string | null>(null);
|
| 36 |
const [versionModels, setVersionModels] = useState<Record<string, VersionModelInfo[]>>({});
|
|
@@ -49,19 +47,14 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 49 |
|
| 50 |
// Poll while a download is in progress
|
| 51 |
useEffect(() => {
|
| 52 |
-
const hasDownloading = versions.some((v) => v.status === "downloading");
|
| 53 |
const hasModelDownloading = Object.values(versionModels)
|
| 54 |
.some((rows) => rows.some((m) => m.status === "downloading"));
|
| 55 |
-
if (hasDownloading && !pollRef.current) {
|
| 56 |
-
pollRef.current = setInterval(refresh, 2500);
|
| 57 |
-
}
|
| 58 |
if (hasModelDownloading && !pollRef.current) {
|
| 59 |
pollRef.current = setInterval(refresh, 2500);
|
| 60 |
}
|
| 61 |
-
if (!
|
| 62 |
clearInterval(pollRef.current);
|
| 63 |
pollRef.current = null;
|
| 64 |
-
setBusy(null);
|
| 65 |
setModelBusy(null);
|
| 66 |
}
|
| 67 |
return () => {
|
|
@@ -90,25 +83,8 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 90 |
return () => document.removeEventListener("mousedown", handler);
|
| 91 |
}, [open]);
|
| 92 |
|
| 93 |
-
const handleDownloadOrLoad = async (version: string) => {
|
| 94 |
-
setBusy(version);
|
| 95 |
-
try {
|
| 96 |
-
const res = await loadVersion(version);
|
| 97 |
-
if (res.status === "ready") {
|
| 98 |
-
// Loaded instantly from disk — auto-switch
|
| 99 |
-
onSwitch(version as "v1" | "v2" | "v3");
|
| 100 |
-
setBusy(null);
|
| 101 |
-
setOpen(false);
|
| 102 |
-
}
|
| 103 |
-
refresh();
|
| 104 |
-
} catch {
|
| 105 |
-
setBusy(null);
|
| 106 |
-
}
|
| 107 |
-
};
|
| 108 |
-
|
| 109 |
const handleSwitch = (version: string) => {
|
| 110 |
onSwitch(version as "v1" | "v2" | "v3");
|
| 111 |
-
setOpen(false);
|
| 112 |
};
|
| 113 |
|
| 114 |
const handleExpandVersion = async (version: string) => {
|
|
@@ -139,7 +115,7 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 139 |
};
|
| 140 |
|
| 141 |
const statusForVersion = (v: VersionInfo) => {
|
| 142 |
-
const isDownloading = v.status === "downloading"
|
| 143 |
const isError = v.status === "error";
|
| 144 |
const isLoaded = v.loaded && v.model_count > 0;
|
| 145 |
const isOnDisk = v.status === "on_disk" || (v.on_disk && !isLoaded && !isDownloading && !isError);
|
|
@@ -147,18 +123,6 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 147 |
return { isDownloading, isError, isLoaded, isOnDisk, isNotDownloaded };
|
| 148 |
};
|
| 149 |
|
| 150 |
-
// Auto-switch when download completes
|
| 151 |
-
useEffect(() => {
|
| 152 |
-
if (busy) {
|
| 153 |
-
const v = versions.find((ver) => ver.id === busy);
|
| 154 |
-
if (v && v.status === "ready" && v.loaded && v.model_count > 0) {
|
| 155 |
-
onSwitch(busy as "v1" | "v2" | "v3");
|
| 156 |
-
setBusy(null);
|
| 157 |
-
setOpen(false);
|
| 158 |
-
}
|
| 159 |
-
}
|
| 160 |
-
}, [versions, busy, onSwitch]);
|
| 161 |
-
|
| 162 |
const activeDisplay = versions.find((v) => v.id === activeVersion)?.display
|
| 163 |
?? `v${activeVersion[1]}.0`;
|
| 164 |
|
|
@@ -195,7 +159,7 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 195 |
</div>
|
| 196 |
|
| 197 |
{versions.map((v) => {
|
| 198 |
-
const { isDownloading,
|
| 199 |
const isActiveVersion = v.id === activeVersion;
|
| 200 |
const isExpanded = expandedVersion === v.id;
|
| 201 |
const models = versionModels[v.id] ?? [];
|
|
@@ -207,7 +171,13 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 207 |
hover:bg-gray-800/60 transition-colors"
|
| 208 |
>
|
| 209 |
<div className="flex-1 min-w-0">
|
| 210 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
{isActiveVersion && (
|
| 212 |
<span className="ml-2 text-xs text-green-400 italic">Active Version</span>
|
| 213 |
)}
|
|
@@ -221,72 +191,20 @@ export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
|
| 221 |
{v.model_count} models
|
| 222 |
</span>
|
| 223 |
)}
|
| 224 |
-
{isOnDisk && (
|
| 225 |
-
<span className="ml-2 text-xs text-yellow-400">on disk</span>
|
| 226 |
-
)}
|
| 227 |
-
{isError && (
|
| 228 |
-
<span className="ml-2 text-xs text-red-400">error</span>
|
| 229 |
-
)}
|
| 230 |
{isDownloading && (
|
| 231 |
<span className="ml-2 text-xs text-yellow-400 animate-pulse">
|
| 232 |
downloading...
|
| 233 |
</span>
|
| 234 |
)}
|
| 235 |
-
{isNotDownloaded && (
|
| 236 |
-
<span className="ml-2 text-xs text-gray-600">not downloaded</span>
|
| 237 |
-
)}
|
| 238 |
</div>
|
| 239 |
|
| 240 |
<div className="flex items-center gap-1 shrink-0 ml-2">
|
| 241 |
-
{isLoaded && !isActiveVersion && (
|
| 242 |
-
<button
|
| 243 |
-
onClick={() => handleSwitch(v.id)}
|
| 244 |
-
className="px-2.5 py-1 rounded text-xs font-medium
|
| 245 |
-
bg-green-700 hover:bg-green-600 text-white transition-colors"
|
| 246 |
-
title={`Switch to ${v.display}`}
|
| 247 |
-
>
|
| 248 |
-
Switch
|
| 249 |
-
</button>
|
| 250 |
-
)}
|
| 251 |
-
|
| 252 |
-
{isOnDisk && !isDownloading && (
|
| 253 |
-
<button
|
| 254 |
-
onClick={() => handleDownloadOrLoad(v.id)}
|
| 255 |
-
className="flex items-center gap-1 px-2.5 py-1 rounded text-xs font-medium
|
| 256 |
-
bg-blue-700 hover:bg-blue-600 text-white transition-colors"
|
| 257 |
-
title={`Load ${v.display} into memory`}
|
| 258 |
-
>
|
| 259 |
-
<HardDrive className="w-3 h-3" />
|
| 260 |
-
Load
|
| 261 |
-
</button>
|
| 262 |
-
)}
|
| 263 |
-
|
| 264 |
-
{isNotDownloaded && !isDownloading && !isError && (
|
| 265 |
-
<button
|
| 266 |
-
onClick={() => handleDownloadOrLoad(v.id)}
|
| 267 |
-
className="flex items-center gap-1 px-2.5 py-1 rounded text-xs font-medium
|
| 268 |
-
bg-gray-700 hover:bg-gray-600 text-gray-300 hover:text-white transition-colors"
|
| 269 |
-
title={`Download ${v.display} from HF Hub`}
|
| 270 |
-
>
|
| 271 |
-
<Download className="w-3 h-3" />
|
| 272 |
-
Download
|
| 273 |
-
</button>
|
| 274 |
-
)}
|
| 275 |
-
|
| 276 |
{isDownloading && (
|
| 277 |
<RefreshCw className="w-3.5 h-3.5 text-yellow-400 animate-spin" />
|
| 278 |
)}
|
| 279 |
|
| 280 |
-
{
|
| 281 |
-
<
|
| 282 |
-
onClick={() => handleDownloadOrLoad(v.id)}
|
| 283 |
-
className="flex items-center gap-1 px-2.5 py-1 rounded text-xs font-medium
|
| 284 |
-
bg-gray-700 hover:bg-red-700 text-red-400 hover:text-white transition-colors"
|
| 285 |
-
title="Retry download"
|
| 286 |
-
>
|
| 287 |
-
<AlertCircle className="w-3 h-3" />
|
| 288 |
-
Retry
|
| 289 |
-
</button>
|
| 290 |
)}
|
| 291 |
|
| 292 |
<button
|
|
|
|
| 11 |
|
| 12 |
import { useCallback, useEffect, useRef, useState } from "react";
|
| 13 |
import {
|
| 14 |
+
ChevronRight, Check, RefreshCw, Layers,
|
| 15 |
} from "lucide-react";
|
| 16 |
import {
|
| 17 |
fetchVersions,
|
|
|
|
| 18 |
VersionInfo,
|
| 19 |
fetchVersionModels,
|
| 20 |
loadVersionModel,
|
|
|
|
| 29 |
export default function VersionSelector({ activeVersion, onSwitch }: Props) {
|
| 30 |
const [open, setOpen] = useState(false);
|
| 31 |
const [versions, setVersions] = useState<VersionInfo[]>([]);
|
|
|
|
| 32 |
const [modelBusy, setModelBusy] = useState<string | null>(null);
|
| 33 |
const [expandedVersion, setExpandedVersion] = useState<string | null>(null);
|
| 34 |
const [versionModels, setVersionModels] = useState<Record<string, VersionModelInfo[]>>({});
|
|
|
|
| 47 |
|
| 48 |
// Poll while a download is in progress
|
| 49 |
useEffect(() => {
|
|
|
|
| 50 |
const hasModelDownloading = Object.values(versionModels)
|
| 51 |
.some((rows) => rows.some((m) => m.status === "downloading"));
|
|
|
|
|
|
|
|
|
|
| 52 |
if (hasModelDownloading && !pollRef.current) {
|
| 53 |
pollRef.current = setInterval(refresh, 2500);
|
| 54 |
}
|
| 55 |
+
if (!hasModelDownloading && pollRef.current) {
|
| 56 |
clearInterval(pollRef.current);
|
| 57 |
pollRef.current = null;
|
|
|
|
| 58 |
setModelBusy(null);
|
| 59 |
}
|
| 60 |
return () => {
|
|
|
|
| 83 |
return () => document.removeEventListener("mousedown", handler);
|
| 84 |
}, [open]);
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
const handleSwitch = (version: string) => {
|
| 87 |
onSwitch(version as "v1" | "v2" | "v3");
|
|
|
|
| 88 |
};
|
| 89 |
|
| 90 |
const handleExpandVersion = async (version: string) => {
|
|
|
|
| 115 |
};
|
| 116 |
|
| 117 |
const statusForVersion = (v: VersionInfo) => {
|
| 118 |
+
const isDownloading = v.status === "downloading";
|
| 119 |
const isError = v.status === "error";
|
| 120 |
const isLoaded = v.loaded && v.model_count > 0;
|
| 121 |
const isOnDisk = v.status === "on_disk" || (v.on_disk && !isLoaded && !isDownloading && !isError);
|
|
|
|
| 123 |
return { isDownloading, isError, isLoaded, isOnDisk, isNotDownloaded };
|
| 124 |
};
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
const activeDisplay = versions.find((v) => v.id === activeVersion)?.display
|
| 127 |
?? `v${activeVersion[1]}.0`;
|
| 128 |
|
|
|
|
| 159 |
</div>
|
| 160 |
|
| 161 |
{versions.map((v) => {
|
| 162 |
+
const { isDownloading, isLoaded } = statusForVersion(v);
|
| 163 |
const isActiveVersion = v.id === activeVersion;
|
| 164 |
const isExpanded = expandedVersion === v.id;
|
| 165 |
const models = versionModels[v.id] ?? [];
|
|
|
|
| 171 |
hover:bg-gray-800/60 transition-colors"
|
| 172 |
>
|
| 173 |
<div className="flex-1 min-w-0">
|
| 174 |
+
<button
|
| 175 |
+
onClick={() => handleSwitch(v.id)}
|
| 176 |
+
className={`text-sm font-medium ${isActiveVersion ? "text-white" : "text-gray-200 hover:text-white"}`}
|
| 177 |
+
title={`Switch to ${v.display}`}
|
| 178 |
+
>
|
| 179 |
+
{v.display}
|
| 180 |
+
</button>
|
| 181 |
{isActiveVersion && (
|
| 182 |
<span className="ml-2 text-xs text-green-400 italic">Active Version</span>
|
| 183 |
)}
|
|
|
|
| 191 |
{v.model_count} models
|
| 192 |
</span>
|
| 193 |
)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
{isDownloading && (
|
| 195 |
<span className="ml-2 text-xs text-yellow-400 animate-pulse">
|
| 196 |
downloading...
|
| 197 |
</span>
|
| 198 |
)}
|
|
|
|
|
|
|
|
|
|
| 199 |
</div>
|
| 200 |
|
| 201 |
<div className="flex items-center gap-1 shrink-0 ml-2">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
{isDownloading && (
|
| 203 |
<RefreshCw className="w-3.5 h-3.5 text-yellow-400 animate-spin" />
|
| 204 |
)}
|
| 205 |
|
| 206 |
+
{isLoaded && (
|
| 207 |
+
<Check className="w-3.5 h-3.5 text-green-400" />
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
)}
|
| 209 |
|
| 210 |
<button
|
scripts/download_models.py
CHANGED
|
@@ -90,6 +90,27 @@ def _ensure_hub():
|
|
| 90 |
"huggingface_hub>=0.23", "-q"])
|
| 91 |
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def download_version(version: str) -> None:
|
| 94 |
"""Download a single version (e.g. 'v1' or 'v2') from HF Hub into artifacts/."""
|
| 95 |
_ensure_hub()
|
|
@@ -100,6 +121,7 @@ def download_version(version: str) -> None:
|
|
| 100 |
allow_patterns=[f"{version}/**"],
|
| 101 |
ignore_patterns=["*.log"],
|
| 102 |
))
|
|
|
|
| 103 |
print(f"[download_models] {version}/ ready")
|
| 104 |
|
| 105 |
|
|
@@ -199,6 +221,7 @@ def download_models(version: str, model_names: list[str]) -> None:
|
|
| 199 |
allow_patterns=allow,
|
| 200 |
ignore_patterns=["*.log"],
|
| 201 |
))
|
|
|
|
| 202 |
print(f"[download_models] Selected model artifacts ready for {version}")
|
| 203 |
|
| 204 |
|
|
@@ -207,17 +230,41 @@ def download_model(version: str, model_name: str) -> None:
|
|
| 207 |
download_models(version, [model_name])
|
| 208 |
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
def download_models_meta_only(versions: list[str]) -> None:
|
| 211 |
"""Download only models.json for listed versions."""
|
| 212 |
_ensure_hub()
|
| 213 |
from huggingface_hub import snapshot_download
|
| 214 |
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 215 |
-
allow = [f"{v}/models.json" for v in versions]
|
| 216 |
print(f"[download_models] Downloading metadata files: {allow}")
|
| 217 |
snapshot_download(**_hf_kwargs(
|
| 218 |
allow_patterns=allow,
|
| 219 |
ignore_patterns=["*.log"],
|
| 220 |
))
|
|
|
|
|
|
|
| 221 |
print("[download_models] Metadata ready")
|
| 222 |
|
| 223 |
|
|
@@ -228,10 +275,18 @@ def download_all() -> None:
|
|
| 228 |
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 229 |
print(f"[download_models] Downloading all versions from {REPO_ID} -> {ARTIFACTS_DIR}")
|
| 230 |
snapshot_download(**_hf_kwargs(ignore_patterns=["*.log"]))
|
|
|
|
|
|
|
| 231 |
SENTINEL.write_text("downloaded\n")
|
| 232 |
print("[download_models] Artifacts ready")
|
| 233 |
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
def main() -> None:
|
| 236 |
import argparse
|
| 237 |
parser = argparse.ArgumentParser()
|
|
|
|
| 90 |
"huggingface_hub>=0.23", "-q"])
|
| 91 |
|
| 92 |
|
| 93 |
+
def write_datamap(version: str) -> None:
|
| 94 |
+
"""Write artifacts/<version>/datamap.json listing all locally available files."""
|
| 95 |
+
vroot = ARTIFACTS_DIR / version
|
| 96 |
+
vroot.mkdir(parents=True, exist_ok=True)
|
| 97 |
+
items = []
|
| 98 |
+
for p in sorted(vroot.rglob("*")):
|
| 99 |
+
if not p.is_file():
|
| 100 |
+
continue
|
| 101 |
+
rel = p.relative_to(vroot).as_posix()
|
| 102 |
+
items.append({
|
| 103 |
+
"path": rel,
|
| 104 |
+
"bytes": p.stat().st_size,
|
| 105 |
+
})
|
| 106 |
+
out = {
|
| 107 |
+
"version": version,
|
| 108 |
+
"count": len(items),
|
| 109 |
+
"files": items,
|
| 110 |
+
}
|
| 111 |
+
(vroot / "datamap.json").write_text(json.dumps(out, indent=2), encoding="utf-8")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
def download_version(version: str) -> None:
|
| 115 |
"""Download a single version (e.g. 'v1' or 'v2') from HF Hub into artifacts/."""
|
| 116 |
_ensure_hub()
|
|
|
|
| 121 |
allow_patterns=[f"{version}/**"],
|
| 122 |
ignore_patterns=["*.log"],
|
| 123 |
))
|
| 124 |
+
write_datamap(version)
|
| 125 |
print(f"[download_models] {version}/ ready")
|
| 126 |
|
| 127 |
|
|
|
|
| 221 |
allow_patterns=allow,
|
| 222 |
ignore_patterns=["*.log"],
|
| 223 |
))
|
| 224 |
+
write_datamap(version)
|
| 225 |
print(f"[download_models] Selected model artifacts ready for {version}")
|
| 226 |
|
| 227 |
|
|
|
|
| 230 |
download_models(version, [model_name])
|
| 231 |
|
| 232 |
|
| 233 |
+
def download_metrics_bundle(version: str) -> None:
|
| 234 |
+
"""Download files needed by metrics page for a specific version."""
|
| 235 |
+
_ensure_hub()
|
| 236 |
+
from huggingface_hub import snapshot_download
|
| 237 |
+
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 238 |
+
allow = [
|
| 239 |
+
f"{version}/models.json",
|
| 240 |
+
f"{version}/datamap.json",
|
| 241 |
+
f"{version}/results/**",
|
| 242 |
+
f"{version}/reports/**",
|
| 243 |
+
f"{version}/features/**",
|
| 244 |
+
f"{version}/figures/**",
|
| 245 |
+
]
|
| 246 |
+
print(f"[download_models] Downloading metrics bundle for {version}")
|
| 247 |
+
snapshot_download(**_hf_kwargs(
|
| 248 |
+
allow_patterns=allow,
|
| 249 |
+
ignore_patterns=["*.log"],
|
| 250 |
+
))
|
| 251 |
+
write_datamap(version)
|
| 252 |
+
print(f"[download_models] Metrics bundle ready for {version}")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
def download_models_meta_only(versions: list[str]) -> None:
|
| 256 |
"""Download only models.json for listed versions."""
|
| 257 |
_ensure_hub()
|
| 258 |
from huggingface_hub import snapshot_download
|
| 259 |
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 260 |
+
allow = [f"{v}/models.json" for v in versions] + [f"{v}/datamap.json" for v in versions]
|
| 261 |
print(f"[download_models] Downloading metadata files: {allow}")
|
| 262 |
snapshot_download(**_hf_kwargs(
|
| 263 |
allow_patterns=allow,
|
| 264 |
ignore_patterns=["*.log"],
|
| 265 |
))
|
| 266 |
+
for v in versions:
|
| 267 |
+
write_datamap(v)
|
| 268 |
print("[download_models] Metadata ready")
|
| 269 |
|
| 270 |
|
|
|
|
| 275 |
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 276 |
print(f"[download_models] Downloading all versions from {REPO_ID} -> {ARTIFACTS_DIR}")
|
| 277 |
snapshot_download(**_hf_kwargs(ignore_patterns=["*.log"]))
|
| 278 |
+
for v in ("v1", "v2", "v3"):
|
| 279 |
+
write_datamap(v)
|
| 280 |
SENTINEL.write_text("downloaded\n")
|
| 281 |
print("[download_models] Artifacts ready")
|
| 282 |
|
| 283 |
|
| 284 |
+
def ensure_metadata_first(versions: list[str] | None = None) -> None:
|
| 285 |
+
"""Guarantee models.json/datamap are present before registry usage."""
|
| 286 |
+
versions = versions or ["v1", "v2", "v3"]
|
| 287 |
+
download_models_meta_only(versions)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
def main() -> None:
|
| 291 |
import argparse
|
| 292 |
parser = argparse.ArgumentParser()
|