Spaces:
Running
Running
"""FastAPI app: training and Hugging Face upload trigger."""
| from __future__ import annotations | |
| import json | |
| import os | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| import schedule | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from huggingface_hub import HfApi, hf_hub_download | |
| try: | |
| from huggingface_hub.utils import HfHubHTTPError | |
| except ImportError: # pragma: no cover | |
| HfHubHTTPError = Exception # type: ignore | |
| from pydantic import BaseModel | |
| from train_scheduler import TrainingScheduler | |
# ASGI application object; the title and description are passed to FastAPI
# as application metadata.
app = FastAPI(
    title="MuscleCare Train Scheduler API",
    description="μλμΌλ‘ λͺ¨λΈ νμ΅ λ° Hugging Face μ λ‘λλ₯Ό νΈλ¦¬κ±°ν©λλ€.",
)

# Single module-wide scheduler instance shared by the functions below and by
# the weekly scheduled job.
_scheduler = TrainingScheduler()
class TrainResponse(BaseModel):
    """Response payload describing the outcome of one training trigger."""

    # Status string copied from the scheduler result ("trained" when a model
    # was produced).
    status: str
    # Count of new data items, copied from the scheduler result.
    new_data_count: int
    # Path of the produced model file, when the scheduler reported one.
    model_path: Optional[str] = None
    # Hugging Face repository URL, set only after a successful upload.
    hub_url: Optional[str] = None
    # Model version number; optional and defaults to None.
    model_version: Optional[int] = None
    # Human-readable summary of what happened.
    message: str
def startup_training() -> None:
    """Run one training pass and print the outcome.

    Intended for server start-up (NOTE(review): no start-up hook registration
    is visible in this view). Every exception is caught and printed, so a
    failed pass never propagates to the caller.
    """
    try:
        print("π μλ² μμ: μλ λͺ¨λΈ νμ΅μ μμν©λλ€...")
        outcome = _scheduler.run_scheduled_training()
        if outcome["status"] != "trained":
            # Nothing was trained; report the scheduler's message if it gave one.
            print(f"βΉοΈ μλ² μμ μ νμ΅ κ±΄λλ: {outcome.get('message', 'μλ‘μ΄ λ°μ΄ν° μμ')}")
            return
        print(f"β μλ² μμ μ νμ΅ μλ£: {outcome['new_data_count']}κ° λ°μ΄ν°λ‘ νμ΅λ¨")
    except Exception as exc:
        print(f"β οΈ μλ² μμ μ νμ΅ μ€ν¨: {exc}")
# Scheduling setup: drop any previously registered jobs, then register the
# training run for every Sunday at 00:00.
schedule.clear()
schedule.every().sunday.at("00:00").do(_scheduler.run_scheduled_training)


def _run_schedule() -> None:
    """Poll the job schedule once a minute, forever, running any due jobs.

    ``schedule.run_pending()`` lets exceptions raised by a job propagate.
    Without a guard, a single failed training run would end this thread and
    silently stop every future scheduled run, so failures are printed and
    polling continues.
    """
    while True:
        try:
            schedule.run_pending()
        except Exception as exc:  # pylint: disable=broad-except
            print(f"Scheduled training failed: {exc}")
        time.sleep(60)


# Run the polling loop on a daemon thread so it never blocks interpreter exit.
threading.Thread(target=_run_schedule, daemon=True).start()
async def health_head():
    """Answer a HEAD health probe.

    A HEAD response carries no body, so there is nothing to return but None.
    """
    return None
def health_check() -> dict:
    """Return the fixed liveness payload ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
def root() -> dict:
    """Return a small index payload: a status message, endpoint paths, and the docs path."""
    endpoint_paths = {
        "health": "/health",
        "trigger": "/trigger",
    }
    index = {
        "message": "MuscleCare Train Scheduler APIκ° μ€ν μ€μ λλ€.",
        "endpoints": endpoint_paths,
        "docs": "/docs",
    }
    return index
def _upload_to_hub(model_path: str) -> Optional[str]:
    """Upload the file at *model_path* to the configured Hugging Face model repo.

    The access token and repository id are read from the environment variables
    HF_E2E_MODEL_TOKEN and HF_E2E_MODEL_REPO_ID. The repository is created as
    a public model repo if it does not exist yet, and the file is stored
    under its base name.

    Args:
        model_path: Local path of the model file to upload.

    Returns:
        The ``https://huggingface.co/<repo_id>`` URL of the repository.

    Raises:
        HTTPException: 400 when either environment variable is missing or
            empty; 404 when *model_path* does not exist.
    """
    hf_token = os.getenv("HF_E2E_MODEL_TOKEN")
    hf_repo = os.getenv("HF_E2E_MODEL_REPO_ID")
    if not (hf_token and hf_repo):
        raise HTTPException(
            status_code=400,
            detail="νκ²½ λ³μ HF_E2E_MODEL_TOKEN / HF_E2E_MODEL_REPO_IDκ° μ€μ λμ΄ μμ§ μμ΅λλ€.",
        )

    model_file = Path(model_path)
    if not model_file.exists():
        raise HTTPException(status_code=404, detail=f"λͺ¨λΈ νμΌμ μ°Ύμ μ μμ΅λλ€: {model_path}")

    client = HfApi(token=hf_token)
    # exist_ok=True makes repeated uploads to the same repo succeed.
    client.create_repo(repo_id=hf_repo, repo_type="model", private=False, exist_ok=True)
    client.upload_file(
        path_or_fileobj=model_file,
        path_in_repo=model_file.name,
        repo_id=hf_repo,
        repo_type="model",
        commit_message="Manual scheduler trigger upload",
    )
    return f"https://huggingface.co/{hf_repo}"
| # TODO: include version info in response body | |
def download_model(
    version: Optional[int] = None,
    filename: Optional[str] = None
) -> FileResponse:
    """Fetch a model file from the Hugging Face Hub and return it as a download.

    Args:
        version: Model version to fetch. When omitted (or 0) the file is
            downloaded without a pinned revision and is reported with the
            version recorded in the local training state. Otherwise the
            version is looked up in the repo's ``model_versions.json`` and
            the file is downloaded at the commit recorded for that version.
        filename: File to download. Defaults to HF_E2E_MODEL_FILE (or
            ``cnn_gru_fatigue.tflite``) when no version is given, or to the
            manifest entry's ``filename`` when a version is given.

    Returns:
        A FileResponse for the downloaded file, carrying ``X-Model-Version``
        and ``X-Model-Filename`` headers.

    Raises:
        HTTPException: 400 when HF_E2E_MODEL_REPO_ID is not set; 404 when the
            requested version is newer than the current one, is missing from
            the manifest, or the Hub answers 404; 500 when the manifest entry
            lacks a filename or commit, or on any other download failure.
    """
    repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
    token = os.getenv("HF_E2E_MODEL_TOKEN")
    default_filename = os.getenv("HF_E2E_MODEL_FILE", "cnn_gru_fatigue.tflite")
    if not repo_id:
        raise HTTPException(
            status_code=400,
            detail="νκ²½ λ³μ HF_E2E_MODEL_REPO_IDκ° μ€μ λμ΄ μμ§ μμ΅λλ€."
        )

    current_state = _scheduler.load_training_state()
    # "or 0" covers a stored model_version of None.
    current_version = int(current_state.get("model_version", 0) or 0)

    # Keyword arguments shared by every Hub download below.
    hub_kwargs: Dict[str, Any] = {
        "repo_id": repo_id,
        "repo_type": "model",
        "token": token,
        "local_dir": "./model_cache",
        "local_dir_use_symlinks": False,
    }

    try:
        if not version:
            # No (or zero) version requested: download without pinning a revision.
            target_filename = filename or default_filename
            local_path = hf_hub_download(filename=target_filename, **hub_kwargs)
            actual_version = current_version
        else:
            if version > current_version:
                raise HTTPException(
                    status_code=404,
                    detail=f"νμ¬ λͺ¨λΈ λ²μ μ {current_version}μ λλ€. λ²μ {version}μ μ‘΄μ¬νμ§ μμ΅λλ€."
                )

            # The manifest maps each version to its filename and commit.
            manifest_path = hf_hub_download(filename="model_versions.json", **hub_kwargs)
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)

            version_entry = next(
                (entry for entry in manifest if entry.get("version") == version),
                None
            )
            if version_entry is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"λ²μ {version}μ ν΄λΉνλ λͺ¨λΈμ μ°Ύμ μ μμ΅λλ€."
                )

            target_filename = filename or version_entry.get("filename")
            target_revision = version_entry.get("commit")
            if not target_filename or not target_revision:
                raise HTTPException(
                    status_code=500,
                    detail=f"λ²μ {version} λ©νλ°μ΄ν°κ° μ¬λ°λ₯΄μ§ μμ΅λλ€."
                )

            local_path = hf_hub_download(
                filename=target_filename,
                revision=target_revision,
                **hub_kwargs,
            )
            actual_version = version
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged;
        # the generic handler below would otherwise rewrite them as 500s.
        raise
    except Exception as exc:
        # Hub HTTP errors expose the upstream response; map its 404 to ours.
        status = getattr(getattr(exc, "response", None), "status_code", None)
        if status == 404:
            raise HTTPException(
                status_code=404,
                detail="νκΉ νμ΄μ€μμ μ§μ ν λͺ¨λΈ νμΌμ μ°Ύμ μ μμ΅λλ€."
            ) from exc
        raise HTTPException(
            status_code=500,
            detail=f"Hugging Face Hub λ€μ΄λ‘λ μ€ν¨: {exc}"
        ) from exc

    served_name = Path(target_filename).name
    response = FileResponse(
        path=local_path,
        filename=served_name,
        media_type="application/octet-stream"
    )
    response.headers["X-Model-Version"] = str(actual_version)
    response.headers["X-Model-Filename"] = served_name
    return response
class ResetStateResponse(BaseModel):
    """Response payload returned after the training state has been reset."""

    # Outcome marker; reset_training_state() always sets this to "reset".
    status: str
    # The state mapping returned by the scheduler's reset call.
    state: Dict[str, Any]
def reset_training_state() -> ResetStateResponse:
    """Reset the scheduler's training state and return the resulting state.

    Returns:
        A ResetStateResponse with status ``"reset"`` and the state returned
        by the scheduler.

    Raises:
        HTTPException: 500 when the reset, or building the response, fails.
    """
    try:
        # Built inside the try so that a malformed state is also reported as a 500.
        return ResetStateResponse(status="reset", state=_scheduler.reset_training_state())
    except Exception as exc:  # pylint: disable=broad-except
        raise HTTPException(status_code=500, detail=f"νμ΅ μν μ΄κΈ°νμ μ€ν¨νμ΅λλ€: {exc}") from exc
def trigger_training(upload: bool = True) -> TrainResponse:
    """Run a training pass now and optionally upload the resulting model.

    Args:
        upload: When true and training produced a model path, the model file
            is uploaded to the Hugging Face Hub.

    Returns:
        A TrainResponse carrying the scheduler's status and data count, the
        model path if any, the Hub URL if an upload happened, and a message.

    Raises:
        HTTPException: 500 when training or the upload fails; HTTP errors
            raised by the upload helper are passed through unchanged.
    """
    try:
        outcome = _scheduler.run_scheduled_training()
    except Exception as exc:  # pylint: disable=broad-except
        raise HTTPException(status_code=500, detail=f"νμ΅ μ€ν μ€ μ€λ₯κ° λ°μνμ΅λλ€: {exc}") from exc

    was_trained = outcome["status"] == "trained"
    artifact_path = outcome.get("model_path")
    uploaded_url = None

    if not was_trained:
        note = "μλ‘μ΄ λ°μ΄ν°κ° μμ΄ νμ΅μ 건λλλλ€."
    elif upload and artifact_path:
        try:
            uploaded_url = _upload_to_hub(artifact_path)
        except HTTPException:
            # Already a well-formed HTTP error; do not re-wrap it.
            raise
        except Exception as exc:  # pylint: disable=broad-except
            raise HTTPException(status_code=500, detail=f"Hugging Face μ λ‘λ μ€ν¨: {exc}") from exc
        note = "λͺ¨λΈ νμ΅ λ° μ λ‘λκ° μλ£λμμ΅λλ€."
    else:
        note = "λͺ¨λΈ νμ΅μ΄ μλ£λμμ΅λλ€."

    return TrainResponse(
        status=outcome["status"],
        new_data_count=outcome["new_data_count"],
        model_path=artifact_path,
        hub_url=uploaded_url,
        message=note,
    )
# Public API of this module: only the application object is exported.
__all__ = ["app"]