# Merry99's picture
# prevent hugging face spaces pause
# ece3e89
"""FastAPI μ•±: μˆ˜λ™ ν•™μŠ΅ 및 Hugging Face μ—…λ‘œλ“œ 트리거"""
from __future__ import annotations
import json
import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional
import schedule
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub import HfApi, hf_hub_download
try:
from huggingface_hub.utils import HfHubHTTPError
except ImportError: # pragma: no cover
HfHubHTTPError = Exception # type: ignore
from pydantic import BaseModel
from train_scheduler import TrainingScheduler
# ASGI application object exposed by this module; the title and description
# are handed to FastAPI as application metadata.
app = FastAPI(
    title="MuscleCare Train Scheduler API",
    description="μˆ˜λ™μœΌλ‘œ λͺ¨λΈ ν•™μŠ΅ 및 Hugging Face μ—…λ‘œλ“œλ₯Ό νŠΈλ¦¬κ±°ν•©λ‹ˆλ‹€.",
)

# One module-wide TrainingScheduler instance, shared by the startup hook and
# by the route handlers in this module.
_scheduler = TrainingScheduler()
class TrainResponse(BaseModel):
    # Response body of POST /trigger.
    # Documented with comments instead of a docstring so that the generated
    # schema description stays exactly as it was.

    # Status string taken from the scheduler result (e.g. "trained").
    status: str
    # Count of new data items, taken from the scheduler result.
    new_data_count: int
    # Path of the trained model file, when the scheduler reported one.
    model_path: Optional[str] = None
    # Hub repository URL, set only when an upload was performed.
    hub_url: Optional[str] = None
    # Version number of the model, when available.
    model_version: Optional[int] = None
    # Human-readable summary of what happened.
    message: str
@app.on_event("startup")
def startup_training() -> None:
    """Run one training pass at server start, then start the weekly schedule.

    The initial training run is best-effort: any exception it raises is
    printed and swallowed so that the scheduling setup below still happens.

    After that, all previously registered ``schedule`` jobs are cleared, a
    job calling ``_scheduler.run_scheduled_training`` is registered for every
    Sunday at 00:00, and a daemon thread is started that polls the schedule
    once every 60 seconds.
    """
    try:
        print("πŸš€ μ„œλ²„ μ‹œμž‘: μžλ™ λͺ¨λΈ ν•™μŠ΅μ„ μ‹œμž‘ν•©λ‹ˆλ‹€...")
        result = _scheduler.run_scheduled_training()
        if result["status"] == "trained":
            print(f"βœ… μ„œλ²„ μ‹œμž‘ μ‹œ ν•™μŠ΅ μ™„λ£Œ: {result['new_data_count']}개 λ°μ΄ν„°λ‘œ ν•™μŠ΅λ¨")
        else:
            print(f"ℹ️ μ„œλ²„ μ‹œμž‘ μ‹œ ν•™μŠ΅ κ±΄λ„ˆλœ€: {result.get('message', 'μƒˆλ‘œμš΄ 데이터 μ—†μŒ')}")
    except Exception as exc:
        print(f"⚠️ μ„œλ²„ μ‹œμž‘ μ‹œ ν•™μŠ΅ μ‹€νŒ¨: {exc}")

    # Existing scheduling setup: start from a clean job list, then register
    # the weekly training job.
    schedule.clear()
    schedule.every().sunday.at("00:00").do(_scheduler.run_scheduled_training)

    def _run_schedule() -> None:
        """Poll the schedule once a minute for the lifetime of the process."""
        while True:
            try:
                schedule.run_pending()
            except Exception as exc:  # pylint: disable=broad-except
                # `schedule` lets job exceptions propagate out of
                # run_pending(); without this guard a single failed run would
                # end this thread and silently stop all future scheduled
                # training.
                print(f"⚠️ Scheduled training failed: {exc}")
            time.sleep(60)

    threading.Thread(target=_run_schedule, daemon=True).start()
# HEAD variant of the health endpoint.
@app.head("/health")
async def health_head():
    # A HEAD response needs no body, so there is nothing to return.
    return None
# GET variant of the health endpoint: always answers with a fixed payload.
@app.get("/health")
def health_check() -> dict:
    return dict(status="ok")
# Landing endpoint: a status message plus pointers to the main routes and the
# interactive docs.
@app.get("/")
def root() -> dict:
    route_index = {"health": "/health", "trigger": "/trigger"}
    return {
        "message": "MuscleCare Train Scheduler APIκ°€ μ‹€ν–‰ μ€‘μž…λ‹ˆλ‹€.",
        "endpoints": route_index,
        "docs": "/docs",
    }
def _upload_to_hub(model_path: str) -> Optional[str]:
    """Upload the file at *model_path* to the configured Hub model repo.

    The access token and repository id are read from the environment
    variables ``HF_E2E_MODEL_TOKEN`` and ``HF_E2E_MODEL_REPO_ID``. The
    repository is created as a public model repo if it does not exist yet,
    and the file is uploaded under its base name.

    Args:
        model_path: Local path of the model file to upload.

    Returns:
        The ``https://huggingface.co/<repo_id>`` URL of the repository.

    Raises:
        HTTPException: 400 when either environment variable is missing or
            empty; 404 when *model_path* does not exist.
    """
    hf_token = os.getenv("HF_E2E_MODEL_TOKEN")
    hf_repo = os.getenv("HF_E2E_MODEL_REPO_ID")
    if not (hf_token and hf_repo):
        raise HTTPException(
            status_code=400,
            detail="ν™˜κ²½ λ³€μˆ˜ HF_E2E_MODEL_TOKEN / HF_E2E_MODEL_REPO_IDκ°€ μ„€μ •λ˜μ–΄ μžˆμ§€ μ•ŠμŠ΅λ‹ˆλ‹€.",
        )

    model_file = Path(model_path)
    if not model_file.exists():
        raise HTTPException(status_code=404, detail=f"λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€: {model_path}")

    hub = HfApi(token=hf_token)
    # exist_ok=True makes repeated uploads to an existing repo a no-op here.
    hub.create_repo(repo_id=hf_repo, repo_type="model", private=False, exist_ok=True)
    hub.upload_file(
        path_or_fileobj=model_file,
        path_in_repo=model_file.name,
        repo_id=hf_repo,
        repo_type="model",
        commit_message="Manual scheduler trigger upload",
    )
    return f"https://huggingface.co/{hf_repo}"
# TODO: include version info in response body
@app.get("/model")
@app.get("/model/{version:int}")
def download_model(
    version: Optional[int] = None,
    filename: Optional[str] = None
) -> FileResponse:
    """Download a model file from the Hugging Face Hub and return it.

    Without a version (``None`` or ``0``) the file is fetched from the
    repository's default revision and the version reported is the
    ``model_version`` stored in the scheduler's training state. With a
    version, ``model_versions.json`` is downloaded from the repository, the
    entry whose ``"version"`` equals the request is looked up, and the file is
    fetched at that entry's ``"commit"`` revision.

    Args:
        version: Model version to download; falsy means "latest".
        filename: File name inside the repository. Defaults to the
            ``HF_E2E_MODEL_FILE`` environment variable (or
            ``cnn_gru_fatigue.tflite``) for the latest model, and to the
            manifest entry's ``"filename"`` for a specific version.

    Returns:
        A ``FileResponse`` streaming the downloaded file, with
        ``X-Model-Version`` and ``X-Model-Filename`` headers set.

    Raises:
        HTTPException: 400 when ``HF_E2E_MODEL_REPO_ID`` is not set; 404 when
            the requested version is newer than the current one, is missing
            from the manifest, or the Hub answers 404; 500 when the manifest
            entry lacks a filename or commit, or on any other download
            failure.
    """
    repo_id = os.getenv("HF_E2E_MODEL_REPO_ID")
    token = os.getenv("HF_E2E_MODEL_TOKEN")
    default_filename = os.getenv("HF_E2E_MODEL_FILE", "cnn_gru_fatigue.tflite")
    if not repo_id:
        raise HTTPException(
            status_code=400,
            detail="ν™˜κ²½ λ³€μˆ˜ HF_E2E_MODEL_REPO_IDκ°€ μ„€μ •λ˜μ–΄ μžˆμ§€ μ•ŠμŠ΅λ‹ˆλ‹€."
        )

    current_state = _scheduler.load_training_state()
    # A missing or null "model_version" is treated as version 0.
    current_version = int(current_state.get("model_version", 0) or 0)

    # Arguments shared by every Hub download issued below.
    hub_kwargs: Dict[str, Any] = {
        "repo_id": repo_id,
        "repo_type": "model",
        "token": token,
        "local_dir": "./model_cache",
        "local_dir_use_symlinks": False,
    }

    try:
        if not version:
            target_filename = filename or default_filename
            local_path = hf_hub_download(filename=target_filename, **hub_kwargs)
            actual_version = current_version
        else:
            if version > current_version:
                raise HTTPException(
                    status_code=404,
                    detail=f"ν˜„μž¬ λͺ¨λΈ 버전은 {current_version}μž…λ‹ˆλ‹€. 버전 {version}은 μ‘΄μž¬ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€."
                )

            manifest_path = hf_hub_download(filename="model_versions.json", **hub_kwargs)
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)

            version_entry = next(
                (entry for entry in manifest if entry.get("version") == version),
                None
            )
            if version_entry is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"버전 {version}에 ν•΄λ‹Ήν•˜λŠ” λͺ¨λΈμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€."
                )

            target_filename = filename or version_entry.get("filename")
            target_revision = version_entry.get("commit")
            if not target_filename or not target_revision:
                raise HTTPException(
                    status_code=500,
                    detail=f"버전 {version} 메타데이터가 μ˜¬λ°”λ₯΄μ§€ μ•ŠμŠ΅λ‹ˆλ‹€."
                )

            local_path = hf_hub_download(
                filename=target_filename,
                revision=target_revision,
                **hub_kwargs,
            )
            actual_version = version
    except HTTPException:
        # The HTTP errors raised deliberately above must reach the client with
        # their own status code; the generic handler below would otherwise
        # re-wrap every one of them as a 500.
        raise
    except Exception as exc:
        # Hub HTTP errors carry the failed response; map its 404 to our 404.
        status = getattr(getattr(exc, "response", None), "status_code", None)
        if status == 404:
            raise HTTPException(
                status_code=404,
                detail="ν—ˆκΉ…νŽ˜μ΄μŠ€μ—μ„œ μ§€μ •ν•œ λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€."
            ) from exc
        raise HTTPException(
            status_code=500,
            detail=f"Hugging Face Hub λ‹€μš΄λ‘œλ“œ μ‹€νŒ¨: {exc}"
        ) from exc

    served_name = Path(target_filename).name
    response = FileResponse(
        path=local_path,
        filename=served_name,
        media_type="application/octet-stream"
    )
    response.headers["X-Model-Version"] = str(actual_version)
    response.headers["X-Model-Filename"] = served_name
    return response
class ResetStateResponse(BaseModel):
    # Response body of POST /state/reset.
    # Documented with comments instead of a docstring so that the generated
    # schema description stays exactly as it was.

    # Outcome marker; the handler sets it to "reset".
    status: str
    # Training state as returned by the scheduler after the reset.
    state: Dict[str, Any]
# Resets the scheduler's training state and echoes the state it returns.
# Any failure, including building the response model, is reported as HTTP 500.
@app.post("/state/reset", response_model=ResetStateResponse)
def reset_training_state() -> ResetStateResponse:
    try:
        payload = ResetStateResponse(
            status="reset",
            state=_scheduler.reset_training_state(),
        )
    except Exception as exc:  # pylint: disable=broad-except
        raise HTTPException(status_code=500, detail=f"ν•™μŠ΅ μƒνƒœ μ΄ˆκΈ°ν™”μ— μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€: {exc}") from exc
    return payload
@app.post("/trigger", response_model=TrainResponse)
def trigger_training(upload: bool = True) -> TrainResponse:
    """Run a training pass now and optionally upload the resulting model.

    Args:
        upload: When true and the run reports status ``"trained"`` together
            with a ``model_path``, the model file is uploaded to the Hub.

    Returns:
        A ``TrainResponse`` built from the scheduler result: status, new data
        count, model path, the Hub URL when an upload happened, the model
        version when the result carries one, and a summary message.

    Raises:
        HTTPException: 500 when training or the upload fails; HTTP errors
            raised by the upload helper are passed through unchanged.
    """
    try:
        result = _scheduler.run_scheduled_training()
    except Exception as exc:  # pylint: disable=broad-except
        raise HTTPException(status_code=500, detail=f"ν•™μŠ΅ μ‹€ν–‰ 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {exc}") from exc

    message = "μƒˆλ‘œμš΄ 데이터가 μ—†μ–΄ ν•™μŠ΅μ„ κ±΄λ„ˆλœλ‹ˆλ‹€."
    hub_url = None
    if result["status"] == "trained":
        message = "λͺ¨λΈ ν•™μŠ΅μ΄ μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€."
        model_path = result.get("model_path")
        if upload and model_path:
            try:
                hub_url = _upload_to_hub(model_path)
                message = "λͺ¨λΈ ν•™μŠ΅ 및 μ—…λ‘œλ“œκ°€ μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€."
            except HTTPException:
                raise
            except Exception as exc:  # pylint: disable=broad-except
                raise HTTPException(status_code=500, detail=f"Hugging Face μ—…λ‘œλ“œ μ‹€νŒ¨: {exc}") from exc

    # TrainResponse declares model_version, but it was never filled in.
    # NOTE(review): assumes the scheduler reports it under "model_version" in
    # its result - TODO confirm. Anything that is not a plain int is ignored,
    # so the field simply stays None as before in that case.
    reported_version = result.get("model_version")
    if isinstance(reported_version, bool) or not isinstance(reported_version, int):
        reported_version = None

    return TrainResponse(
        status=result["status"],
        new_data_count=result["new_data_count"],
        model_path=result.get("model_path"),
        hub_url=hub_url,
        model_version=reported_version,
        message=message,
    )
# Only the FastAPI application object is exported via `from <module> import *`.
__all__ = ["app"]