Sanuka0523 commited on
Commit
11e9a40
·
1 Parent(s): a86336f

Deploy cardiac monitor FastAPI backend

Browse files

- FastAPI with JWT auth, health profiles, vitals, ML predictions
- ECGFounder + XGBoost ensemble for cardiac risk assessment
- Dockerfile downloads models from GitHub LFS during build
- CPU-only PyTorch for optimized container size

.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.pyc
3
+ .env
4
+ .env.*
5
+ .git
6
+ .gitignore
7
+ *.md
8
+ .venv
9
+ venv
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install CPU-only PyTorch first (saves ~1.5GB vs full torch)
6
+ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
7
+
8
+ # Install remaining dependencies
9
+ COPY requirements.txt .
10
+ RUN pip install --no-cache-dir -r requirements.txt
11
+
12
+ # Copy application code
13
+ COPY app/ app/
14
+ COPY ml_src/ ml_src/
15
+
16
+ # Download ML models from GitHub LFS (avoids LFS issues on HF Spaces)
17
+ RUN mkdir -p ml_models && \
18
+ apt-get update && apt-get install -y --no-install-recommends curl && \
19
+ curl -L -o ml_models/ecgfounder_best.pt \
20
+ "https://media.githubusercontent.com/media/Sanuka23/cardiac-monitor/main/backend/ml_models/ecgfounder_best.pt" && \
21
+ curl -L -o ml_models/xgboost_cardiac.joblib \
22
+ "https://media.githubusercontent.com/media/Sanuka23/cardiac-monitor/main/backend/ml_models/xgboost_cardiac.joblib" && \
23
+ apt-get purge -y curl && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* && \
24
+ echo "Model sizes:" && ls -lh ml_models/
25
+
26
+ EXPOSE 7860
27
+
28
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,23 @@
1
  ---
2
  title: Cardiac Monitor API
3
- emoji: 💻
4
  colorFrom: red
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
  license: mit
9
- short_description: FastAPI backend for ESP32 Heart Rate, SpO2 & ECG cardiac mon
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Cardiac Monitor API
3
+ emoji: ❤️
4
  colorFrom: red
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  license: mit
9
+ app_port: 7860
10
+ short_description: Cardiac monitoring API with ML risk prediction
11
  ---
12
 
13
+ # Cardiac Monitor API
14
+
15
+ FastAPI backend for ESP32 Heart Rate, SpO2 & ECG cardiac monitoring system with ML-based risk prediction.
16
+
17
+ ## Endpoints
18
+
19
+ - `GET /api/v1/health` — Health check
20
+ - `POST /api/v1/auth/register` — Register
21
+ - `POST /api/v1/auth/login` — Login (JWT)
22
+ - `GET /api/v1/vitals/{device_id}` — Vitals history
23
+ - `GET /api/v1/predictions/{device_id}` — Risk predictions
app/__init__.py ADDED
File without changes
app/config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings
2
+
3
+
4
+ class Settings(BaseSettings):
5
+ MONGODB_URI: str = "mongodb://localhost:27017"
6
+ DATABASE_NAME: str = "cardiac_monitor"
7
+ JWT_SECRET: str = "changeme"
8
+ JWT_ALGORITHM: str = "HS256"
9
+ JWT_EXPIRY_HOURS: int = 24
10
+ API_KEY: str = "dev-api-key"
11
+
12
+ class Config:
13
+ env_file = ".env"
14
+
15
+
16
+ settings = Settings()
app/database.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from motor.motor_asyncio import AsyncIOMotorClient
2
+
3
+ client: AsyncIOMotorClient = None
4
+ db = None
5
+
6
+
7
+ async def connect_db(uri: str, db_name: str):
8
+ global client, db
9
+ client = AsyncIOMotorClient(uri, serverSelectionTimeoutMS=5000)
10
+ db = client[db_name]
11
+
12
+ # Create indexes (non-blocking — if DB is unreachable, server still starts)
13
+ try:
14
+ await db.users.create_index("email", unique=True)
15
+ await db.devices.create_index("device_id", unique=True)
16
+ await db.vitals.create_index([("device_id", 1), ("timestamp", -1)])
17
+ await db.predictions.create_index([("device_id", 1), ("created_at", -1)])
18
+ print("[DB] Connected and indexes created.")
19
+ except Exception as e:
20
+ print(f"[DB] Warning: Could not create indexes: {e}")
21
+ print("[DB] Server will start but DB operations may fail until MongoDB is available.")
22
+
23
+
24
+ async def close_db():
25
+ global client
26
+ if client:
27
+ client.close()
28
+
29
+
30
+ def get_db():
31
+ return db
app/main.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+
3
+ from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+
6
+ from app.config import settings
7
+ from app.database import connect_db, close_db
8
+ from app.routes import vitals, auth, devices, health, predictions
9
+
10
+
11
+ @asynccontextmanager
12
+ async def lifespan(app: FastAPI):
13
+ await connect_db(settings.MONGODB_URI, settings.DATABASE_NAME)
14
+
15
+ # Load ML models (non-blocking — server starts even if models aren't ready)
16
+ try:
17
+ from app.services.ml_service import load_models
18
+ load_models()
19
+ except Exception as e:
20
+ print(f"[ML] Could not load models: {e}")
21
+ print("[ML] Server will run without predictions until models are available.")
22
+
23
+ yield
24
+ await close_db()
25
+
26
+
27
+ app = FastAPI(
28
+ title="Cardiac Monitor API",
29
+ version="1.0.0",
30
+ description="Backend for ESP32 Heart Rate, SpO2 & ECG Monitor",
31
+ lifespan=lifespan,
32
+ )
33
+
34
+ app.add_middleware(
35
+ CORSMiddleware,
36
+ allow_origins=["*"],
37
+ allow_credentials=True,
38
+ allow_methods=["*"],
39
+ allow_headers=["*"],
40
+ )
41
+
42
+ app.include_router(health.router, prefix="/api/v1", tags=["health"])
43
+ app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
44
+ app.include_router(vitals.router, prefix="/api/v1/vitals", tags=["vitals"])
45
+ app.include_router(predictions.router, prefix="/api/v1/predictions", tags=["predictions"])
46
+ app.include_router(devices.router, prefix="/api/v1/devices", tags=["devices"])
app/middleware/__init__.py ADDED
File without changes
app/middleware/auth.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+
3
+ import bcrypt
4
+ from bson import ObjectId
5
+ from fastapi import Depends, HTTPException, Header
6
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
+ from jose import jwt, JWTError
8
+
9
+ from app.config import settings
10
+ from app.database import get_db
11
+
12
+ security = HTTPBearer()
13
+
14
+
15
+ def hash_password(password: str) -> str:
16
+ return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
17
+
18
+
19
+ def verify_password(plain: str, hashed: str) -> bool:
20
+ return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
21
+
22
+
23
+ def create_access_token(user_id: str) -> str:
24
+ expire = datetime.utcnow() + timedelta(hours=settings.JWT_EXPIRY_HOURS)
25
+ payload = {"sub": user_id, "exp": expire}
26
+ return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
27
+
28
+
29
+ async def verify_api_key(x_api_key: str = Header(...)):
30
+ if x_api_key != settings.API_KEY:
31
+ raise HTTPException(status_code=401, detail="Invalid API key")
32
+ return x_api_key
33
+
34
+
35
+ async def get_current_user(
36
+ credentials: HTTPAuthorizationCredentials = Depends(security),
37
+ ):
38
+ try:
39
+ payload = jwt.decode(
40
+ credentials.credentials,
41
+ settings.JWT_SECRET,
42
+ algorithms=[settings.JWT_ALGORITHM],
43
+ )
44
+ user_id = payload.get("sub")
45
+ if user_id is None:
46
+ raise HTTPException(status_code=401, detail="Invalid token")
47
+ except JWTError:
48
+ raise HTTPException(status_code=401, detail="Invalid token")
49
+
50
+ db = get_db()
51
+ user = await db.users.find_one({"_id": ObjectId(user_id)})
52
+ if user is None:
53
+ raise HTTPException(status_code=401, detail="User not found")
54
+ return user
app/models/__init__.py ADDED
File without changes
app/models/device.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional
3
+ from datetime import datetime
4
+
5
+
6
+ class DeviceRegister(BaseModel):
7
+ device_id: str = Field(..., min_length=1, max_length=50)
8
+ name: Optional[str] = Field(None, max_length=100)
9
+
10
+
11
+ class DeviceResponse(BaseModel):
12
+ device_id: str
13
+ name: Optional[str] = None
14
+ owner_user_id: Optional[str] = None
15
+ last_seen: Optional[datetime] = None
16
+ registered_at: datetime
app/models/prediction.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional
3
+ from datetime import datetime
4
+
5
+
6
+ class PredictionResponse(BaseModel):
7
+ id: str
8
+ vitals_id: str
9
+ device_id: str
10
+ risk_score: float = Field(..., ge=0.0, le=1.0)
11
+ risk_label: str
12
+ confidence: float = Field(..., ge=0.0, le=1.0)
13
+ features: dict
14
+ model_version: str
15
+ created_at: datetime
app/models/user.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr, Field
2
+ from typing import List, Optional
3
+ from datetime import datetime
4
+
5
+
6
+ class HealthProfile(BaseModel):
7
+ age: Optional[int] = Field(None, ge=1, le=120)
8
+ sex: Optional[str] = Field(None, pattern="^(male|female|other)$")
9
+ height_cm: Optional[float] = Field(None, ge=50, le=300)
10
+ weight_kg: Optional[float] = Field(None, ge=10, le=500)
11
+ is_diabetic: bool = False
12
+ is_hypertensive: bool = False
13
+ is_smoker: bool = False
14
+ family_history: bool = False
15
+ known_conditions: List[str] = Field(default_factory=list)
16
+ medications: List[str] = Field(default_factory=list)
17
+
18
+
19
+ class UserRegister(BaseModel):
20
+ email: str = Field(..., min_length=5, max_length=100)
21
+ password: str = Field(..., min_length=6, max_length=100)
22
+ name: str = Field(..., min_length=1, max_length=100)
23
+
24
+
25
+ class UserLogin(BaseModel):
26
+ email: str
27
+ password: str
28
+
29
+
30
+ class UserResponse(BaseModel):
31
+ id: str
32
+ email: str
33
+ name: str
34
+ device_ids: List[str] = Field(default_factory=list)
35
+ profile: Optional[HealthProfile] = None
36
+ created_at: datetime
37
+
38
+
39
+ class TokenResponse(BaseModel):
40
+ access_token: str
41
+ token_type: str = "bearer"
app/models/vitals.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Optional
3
+ from datetime import datetime
4
+
5
+
6
+ class VitalsCreate(BaseModel):
7
+ device_id: str = Field(..., min_length=1, max_length=50)
8
+ timestamp: int = Field(..., description="Unix epoch seconds from ESP32")
9
+ window_ms: int = Field(default=10000, ge=1000, le=60000)
10
+ sample_rate_hz: int = Field(default=100, ge=50, le=1000)
11
+ heart_rate_bpm: float = Field(..., ge=0, le=300)
12
+ spo2_percent: int = Field(..., ge=0, le=100)
13
+ ecg_lead_off: bool = Field(default=False)
14
+ ecg_samples: List[int] = Field(..., min_length=100, max_length=6000)
15
+ beat_timestamps_ms: List[int] = Field(default_factory=list)
16
+
17
+
18
+ class VitalsResponse(BaseModel):
19
+ id: str
20
+ device_id: str
21
+ timestamp: datetime
22
+ heart_rate_bpm: float
23
+ spo2_percent: int
24
+ ecg_lead_off: bool
25
+ sample_count: int
26
+ prediction: Optional[dict] = None
27
+ created_at: datetime
28
+
29
+
30
+ class VitalsListResponse(BaseModel):
31
+ vitals: List[VitalsResponse]
32
+ total: int
app/routes/__init__.py ADDED
File without changes
app/routes/auth.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+
5
+ from app.database import get_db
6
+ from app.middleware.auth import (
7
+ hash_password,
8
+ verify_password,
9
+ create_access_token,
10
+ get_current_user,
11
+ )
12
+ from app.models.user import (
13
+ UserRegister,
14
+ UserLogin,
15
+ UserResponse,
16
+ TokenResponse,
17
+ HealthProfile,
18
+ )
19
+
20
+ router = APIRouter()
21
+
22
+
23
+ @router.post("/register", response_model=TokenResponse)
24
+ async def register(data: UserRegister):
25
+ db = get_db()
26
+
27
+ existing = await db.users.find_one({"email": data.email})
28
+ if existing:
29
+ raise HTTPException(status_code=400, detail="Email already registered")
30
+
31
+ user_doc = {
32
+ "email": data.email,
33
+ "password_hash": hash_password(data.password),
34
+ "name": data.name,
35
+ "device_ids": [],
36
+ "profile": None,
37
+ "created_at": datetime.utcnow(),
38
+ }
39
+ result = await db.users.insert_one(user_doc)
40
+ token = create_access_token(str(result.inserted_id))
41
+ return TokenResponse(access_token=token)
42
+
43
+
44
+ @router.post("/login", response_model=TokenResponse)
45
+ async def login(data: UserLogin):
46
+ db = get_db()
47
+
48
+ user = await db.users.find_one({"email": data.email})
49
+ if not user or not verify_password(data.password, user["password_hash"]):
50
+ raise HTTPException(status_code=401, detail="Invalid email or password")
51
+
52
+ token = create_access_token(str(user["_id"]))
53
+ return TokenResponse(access_token=token)
54
+
55
+
56
+ @router.get("/me", response_model=UserResponse)
57
+ async def get_me(user=Depends(get_current_user)):
58
+ return UserResponse(
59
+ id=str(user["_id"]),
60
+ email=user["email"],
61
+ name=user["name"],
62
+ device_ids=user.get("device_ids", []),
63
+ profile=user.get("profile"),
64
+ created_at=user["created_at"],
65
+ )
66
+
67
+
68
+ @router.put("/profile", response_model=UserResponse)
69
+ async def update_profile(profile: HealthProfile, user=Depends(get_current_user)):
70
+ db = get_db()
71
+
72
+ await db.users.update_one(
73
+ {"_id": user["_id"]},
74
+ {"$set": {"profile": profile.model_dump()}},
75
+ )
76
+
77
+ user["profile"] = profile.model_dump()
78
+ return UserResponse(
79
+ id=str(user["_id"]),
80
+ email=user["email"],
81
+ name=user["name"],
82
+ device_ids=user.get("device_ids", []),
83
+ profile=user["profile"],
84
+ created_at=user["created_at"],
85
+ )
app/routes/devices.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+
5
+ from app.database import get_db
6
+ from app.middleware.auth import get_current_user
7
+ from app.models.device import DeviceRegister, DeviceResponse
8
+
9
+ router = APIRouter()
10
+
11
+
12
+ @router.post("/register", response_model=DeviceResponse)
13
+ async def register_device(data: DeviceRegister, user=Depends(get_current_user)):
14
+ db = get_db()
15
+
16
+ existing = await db.devices.find_one({"device_id": data.device_id})
17
+ if existing:
18
+ # If device already exists, link to this user if not already linked
19
+ if existing.get("owner_user_id") and str(existing["owner_user_id"]) != str(
20
+ user["_id"]
21
+ ):
22
+ raise HTTPException(
23
+ status_code=400, detail="Device already registered to another user"
24
+ )
25
+ await db.devices.update_one(
26
+ {"device_id": data.device_id},
27
+ {
28
+ "$set": {
29
+ "owner_user_id": str(user["_id"]),
30
+ "name": data.name or existing.get("name"),
31
+ }
32
+ },
33
+ )
34
+ existing["owner_user_id"] = str(user["_id"])
35
+ if data.name:
36
+ existing["name"] = data.name
37
+ doc = existing
38
+ else:
39
+ doc = {
40
+ "device_id": data.device_id,
41
+ "name": data.name,
42
+ "owner_user_id": str(user["_id"]),
43
+ "last_seen": None,
44
+ "registered_at": datetime.utcnow(),
45
+ }
46
+ await db.devices.insert_one(doc)
47
+
48
+ # Add device_id to user's device list
49
+ await db.users.update_one(
50
+ {"_id": user["_id"]},
51
+ {"$addToSet": {"device_ids": data.device_id}},
52
+ )
53
+
54
+ return DeviceResponse(
55
+ device_id=doc["device_id"],
56
+ name=doc.get("name"),
57
+ owner_user_id=doc.get("owner_user_id"),
58
+ last_seen=doc.get("last_seen"),
59
+ registered_at=doc.get("registered_at", datetime.utcnow()),
60
+ )
61
+
62
+
63
+ @router.get("", response_model=list[DeviceResponse])
64
+ async def list_devices(user=Depends(get_current_user)):
65
+ db = get_db()
66
+
67
+ device_ids = user.get("device_ids", [])
68
+ if not device_ids:
69
+ return []
70
+
71
+ cursor = db.devices.find({"device_id": {"$in": device_ids}})
72
+ docs = await cursor.to_list(length=100)
73
+
74
+ return [
75
+ DeviceResponse(
76
+ device_id=doc["device_id"],
77
+ name=doc.get("name"),
78
+ owner_user_id=doc.get("owner_user_id"),
79
+ last_seen=doc.get("last_seen"),
80
+ registered_at=doc.get("registered_at", datetime.utcnow()),
81
+ )
82
+ for doc in docs
83
+ ]
app/routes/health.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+
3
+ from app.database import get_db
4
+
5
+ router = APIRouter()
6
+
7
+
8
+ @router.get("/health")
9
+ async def health_check():
10
+ db = get_db()
11
+ try:
12
+ await db.command("ping")
13
+ db_status = "connected"
14
+ except Exception:
15
+ db_status = "disconnected"
16
+
17
+ return {"status": "ok", "db": db_status, "version": "1.0.0"}
app/routes/predictions.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, Query
2
+
3
+ from app.database import get_db
4
+ from app.middleware.auth import get_current_user
5
+ from app.models.prediction import PredictionResponse
6
+
7
+ router = APIRouter()
8
+
9
+
10
+ @router.get("/{device_id}/latest", response_model=PredictionResponse)
11
+ async def get_latest_prediction(device_id: str, _=Depends(get_current_user)):
12
+ db = get_db()
13
+
14
+ doc = await db.predictions.find_one(
15
+ {"device_id": device_id},
16
+ sort=[("created_at", -1)],
17
+ )
18
+ if not doc:
19
+ raise HTTPException(
20
+ status_code=404, detail="No predictions found for this device"
21
+ )
22
+
23
+ return PredictionResponse(
24
+ id=str(doc["_id"]),
25
+ vitals_id=doc["vitals_id"],
26
+ device_id=doc["device_id"],
27
+ risk_score=doc["risk_score"],
28
+ risk_label=doc["risk_label"],
29
+ confidence=doc["confidence"],
30
+ features=doc.get("features", {}),
31
+ model_version=doc.get("model_version", "none"),
32
+ created_at=doc["created_at"],
33
+ )
34
+
35
+
36
+ @router.get("/{device_id}")
37
+ async def get_prediction_history(
38
+ device_id: str,
39
+ limit: int = Query(default=50, ge=1, le=500),
40
+ offset: int = Query(default=0, ge=0),
41
+ _=Depends(get_current_user),
42
+ ):
43
+ db = get_db()
44
+
45
+ cursor = (
46
+ db.predictions.find({"device_id": device_id})
47
+ .sort("created_at", -1)
48
+ .skip(offset)
49
+ .limit(limit)
50
+ )
51
+ docs = await cursor.to_list(length=limit)
52
+
53
+ results = [
54
+ PredictionResponse(
55
+ id=str(doc["_id"]),
56
+ vitals_id=doc["vitals_id"],
57
+ device_id=doc["device_id"],
58
+ risk_score=doc["risk_score"],
59
+ risk_label=doc["risk_label"],
60
+ confidence=doc["confidence"],
61
+ features=doc.get("features", {}),
62
+ model_version=doc.get("model_version", "none"),
63
+ created_at=doc["created_at"],
64
+ )
65
+ for doc in docs
66
+ ]
67
+
68
+ total = await db.predictions.count_documents({"device_id": device_id})
69
+ return {"predictions": results, "total": total}
app/routes/vitals.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ from bson import ObjectId
4
+ from fastapi import APIRouter, Depends, HTTPException, Query
5
+
6
+ from app.database import get_db
7
+ from app.middleware.auth import verify_api_key, get_current_user
8
+ from app.models.vitals import VitalsCreate, VitalsResponse, VitalsListResponse
9
+
10
+ router = APIRouter()
11
+
12
+
13
+ def _vitals_doc_to_response(doc: dict, prediction: dict = None) -> VitalsResponse:
14
+ return VitalsResponse(
15
+ id=str(doc["_id"]),
16
+ device_id=doc["device_id"],
17
+ timestamp=doc["timestamp"],
18
+ heart_rate_bpm=doc["heart_rate_bpm"],
19
+ spo2_percent=doc["spo2_percent"],
20
+ ecg_lead_off=doc["ecg_lead_off"],
21
+ sample_count=len(doc.get("ecg_samples", [])),
22
+ prediction=prediction,
23
+ created_at=doc["created_at"],
24
+ )
25
+
26
+
27
+ @router.post("", response_model=VitalsResponse)
28
+ async def upload_vitals(data: VitalsCreate, _=Depends(verify_api_key)):
29
+ db = get_db()
30
+
31
+ vitals_doc = {
32
+ "device_id": data.device_id,
33
+ "timestamp": datetime.utcfromtimestamp(data.timestamp),
34
+ "window_ms": data.window_ms,
35
+ "sample_rate_hz": data.sample_rate_hz,
36
+ "heart_rate_bpm": data.heart_rate_bpm,
37
+ "spo2_percent": data.spo2_percent,
38
+ "ecg_lead_off": data.ecg_lead_off,
39
+ "ecg_samples": data.ecg_samples,
40
+ "beat_timestamps_ms": data.beat_timestamps_ms,
41
+ "created_at": datetime.utcnow(),
42
+ }
43
+ result = await db.vitals.insert_one(vitals_doc)
44
+ vitals_doc["_id"] = result.inserted_id
45
+
46
+ # Update device last_seen
47
+ await db.devices.update_one(
48
+ {"device_id": data.device_id},
49
+ {"$set": {"last_seen": datetime.utcnow()}},
50
+ )
51
+
52
+ # Run ML prediction if models are available
53
+ prediction = None
54
+ try:
55
+ from app.services.ml_service import predict, _models_loaded, load_models
56
+
57
+ if not _models_loaded:
58
+ load_models()
59
+
60
+ if not data.ecg_lead_off and len(data.ecg_samples) >= 100:
61
+ # Get user profile for personalized prediction
62
+ device_doc = await db.devices.find_one({"device_id": data.device_id})
63
+ user_profile = None
64
+ history_features = None
65
+
66
+ if device_doc and device_doc.get("owner_user_id"):
67
+ user = await db.users.find_one({"_id": ObjectId(device_doc["owner_user_id"])})
68
+ if user and user.get("profile"):
69
+ user_profile = user["profile"]
70
+
71
+ # Compute historical baselines
72
+ from datetime import timedelta
73
+ now = datetime.utcnow()
74
+ pipeline_24h = [
75
+ {"$match": {"device_id": data.device_id,
76
+ "created_at": {"$gte": now - timedelta(hours=24)}}},
77
+ {"$group": {
78
+ "_id": None,
79
+ "avg_hr": {"$avg": "$heart_rate_bpm"},
80
+ "std_hr": {"$stdDevPop": "$heart_rate_bpm"},
81
+ "avg_spo2": {"$avg": "$spo2_percent"},
82
+ "std_spo2": {"$stdDevPop": "$spo2_percent"},
83
+ "count": {"$sum": 1},
84
+ }},
85
+ ]
86
+ pipeline_7d = [
87
+ {"$match": {"device_id": data.device_id,
88
+ "created_at": {"$gte": now - timedelta(days=7)}}},
89
+ {"$group": {
90
+ "_id": None,
91
+ "avg_hr": {"$avg": "$heart_rate_bpm"},
92
+ "avg_spo2": {"$avg": "$spo2_percent"},
93
+ }},
94
+ ]
95
+
96
+ stats_24h = await db.vitals.aggregate(pipeline_24h).to_list(1)
97
+ stats_7d = await db.vitals.aggregate(pipeline_7d).to_list(1)
98
+
99
+ if stats_24h:
100
+ s = stats_24h[0]
101
+ hr_std = s.get("std_hr", 1) or 1
102
+ spo2_std = s.get("std_spo2", 1) or 1
103
+ history_features = {
104
+ "hr_baseline_24h": s.get("avg_hr", 0),
105
+ "spo2_baseline_24h": s.get("avg_spo2", 0),
106
+ "hr_deviation": abs(data.heart_rate_bpm - s.get("avg_hr", data.heart_rate_bpm)) / hr_std,
107
+ "spo2_deviation": abs(data.spo2_percent - s.get("avg_spo2", data.spo2_percent)) / spo2_std,
108
+ "readings_count_24h": s.get("count", 0),
109
+ }
110
+ if stats_7d:
111
+ if history_features is None:
112
+ history_features = {}
113
+ history_features["hr_baseline_7d"] = stats_7d[0].get("avg_hr", 0)
114
+
115
+ ml_result = predict(
116
+ ecg_samples=data.ecg_samples,
117
+ sample_rate_hz=data.sample_rate_hz,
118
+ heart_rate_bpm=data.heart_rate_bpm,
119
+ spo2_percent=data.spo2_percent,
120
+ user_profile=user_profile,
121
+ history_features=history_features,
122
+ )
123
+
124
+ if ml_result["risk_label"] != "unknown":
125
+ pred_doc = {
126
+ "vitals_id": str(result.inserted_id),
127
+ "device_id": data.device_id,
128
+ "risk_score": ml_result["risk_score"],
129
+ "risk_label": ml_result["risk_label"],
130
+ "confidence": ml_result["confidence"],
131
+ "features": ml_result["features"],
132
+ "model_version": ml_result["model_version"],
133
+ "created_at": datetime.utcnow(),
134
+ }
135
+ await db.predictions.insert_one(pred_doc)
136
+ prediction = {
137
+ "risk_score": ml_result["risk_score"],
138
+ "risk_label": ml_result["risk_label"],
139
+ "confidence": ml_result["confidence"],
140
+ }
141
+ except ImportError:
142
+ pass # ML dependencies not installed, skip prediction
143
+ except Exception as e:
144
+ print(f"[ML] Prediction error: {e}")
145
+
146
+ return _vitals_doc_to_response(vitals_doc, prediction)
147
+
148
+
149
+ @router.get("/{device_id}/latest", response_model=VitalsResponse)
150
+ async def get_latest_vitals(device_id: str, _=Depends(get_current_user)):
151
+ db = get_db()
152
+
153
+ doc = await db.vitals.find_one(
154
+ {"device_id": device_id},
155
+ sort=[("timestamp", -1)],
156
+ )
157
+ if not doc:
158
+ raise HTTPException(status_code=404, detail="No vitals found for this device")
159
+
160
+ # Attach latest prediction if exists
161
+ pred = await db.predictions.find_one(
162
+ {"vitals_id": str(doc["_id"])},
163
+ )
164
+ pred_dict = None
165
+ if pred:
166
+ pred_dict = {
167
+ "risk_score": pred["risk_score"],
168
+ "risk_label": pred["risk_label"],
169
+ "confidence": pred["confidence"],
170
+ }
171
+
172
+ return _vitals_doc_to_response(doc, pred_dict)
173
+
174
+
175
+ @router.get("/{device_id}", response_model=VitalsListResponse)
176
+ async def get_vitals_history(
177
+ device_id: str,
178
+ limit: int = Query(default=50, ge=1, le=500),
179
+ offset: int = Query(default=0, ge=0),
180
+ _=Depends(get_current_user),
181
+ ):
182
+ db = get_db()
183
+
184
+ total = await db.vitals.count_documents({"device_id": device_id})
185
+ cursor = (
186
+ db.vitals.find(
187
+ {"device_id": device_id},
188
+ {"ecg_samples": 0}, # Exclude raw samples from list view
189
+ )
190
+ .sort("timestamp", -1)
191
+ .skip(offset)
192
+ .limit(limit)
193
+ )
194
+ docs = await cursor.to_list(length=limit)
195
+
196
+ vitals_list = [_vitals_doc_to_response(doc) for doc in docs]
197
+ return VitalsListResponse(vitals=vitals_list, total=total)
app/services/__init__.py ADDED
File without changes
app/services/feature_extractor.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ECG Feature Extraction for backend inference.
3
+ Thin wrapper that imports from ml/src/feature_extractor.py to stay in sync.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+
9
+ # Import from the bundled ML source
10
+ ML_SRC = os.path.join(os.path.dirname(__file__), "..", "..", "ml_src")
11
+ sys.path.insert(0, ML_SRC)
12
+
13
+ from feature_extractor import (
14
+ extract_ecg_features,
15
+ features_to_array,
16
+ FEATURE_NAMES,
17
+ PROFILE_FEATURE_NAMES,
18
+ HISTORY_FEATURE_NAMES,
19
+ ALL_FEATURE_NAMES,
20
+ )
21
+
22
+ __all__ = [
23
+ "extract_ecg_features",
24
+ "features_to_array",
25
+ "FEATURE_NAMES",
26
+ "PROFILE_FEATURE_NAMES",
27
+ "HISTORY_FEATURE_NAMES",
28
+ "ALL_FEATURE_NAMES",
29
+ ]
app/services/ml_service.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ML inference service for cardiac risk prediction.
3
+ Loads ECGFounder + XGBoost ensemble and runs prediction on incoming vitals.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import numpy as np
9
+ from scipy.signal import resample
10
+ import torch
11
+ import joblib
12
+
13
+ # Import model architecture from bundled ml_src/
14
+ BACKEND_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
15
+ ML_SRC = os.path.join(BACKEND_ROOT, "ml_src")
16
+ sys.path.insert(0, ML_SRC)
17
+
18
+ from net1d import Net1D
19
+ from feature_extractor import extract_ecg_features, features_to_array, FEATURE_NAMES
20
+
21
+ # Model paths — bundled in backend/ml_models/
22
+ MODEL_DIR = os.path.join(BACKEND_ROOT, "ml_models")
23
+
24
+ # Ensemble weights
25
+ ECG_WEIGHT = 0.60
26
+ XGB_WEIGHT = 0.40
27
+
28
+ # Risk labels
29
+ RISK_LABELS = {
30
+ (0.0, 0.2): "normal",
31
+ (0.2, 0.4): "low",
32
+ (0.4, 0.6): "moderate",
33
+ (0.6, 0.8): "elevated",
34
+ (0.8, 1.01): "high",
35
+ }
36
+
37
+ # Global model instances
38
+ _ecg_model = None
39
+ _xgb_model = None
40
+ _device = None
41
+ _models_loaded = False
42
+
43
+
44
+ def get_risk_label(score: float) -> str:
45
+ for (low, high), label in RISK_LABELS.items():
46
+ if low <= score < high:
47
+ return label
48
+ return "high"
49
+
50
+
51
+ def load_models():
52
+ """Load both models into memory. Called once at startup."""
53
+ global _ecg_model, _xgb_model, _device, _models_loaded
54
+
55
+ if _models_loaded:
56
+ return True
57
+
58
+ # Select device
59
+ if torch.cuda.is_available():
60
+ _device = torch.device("cuda:0")
61
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
62
+ _device = torch.device("mps")
63
+ else:
64
+ _device = torch.device("cpu")
65
+
66
+ ecg_path = os.path.join(MODEL_DIR, "ecgfounder_best.pt")
67
+ xgb_path = os.path.join(MODEL_DIR, "xgboost_cardiac.joblib")
68
+
69
+ # Load ECGFounder
70
+ if os.path.exists(ecg_path):
71
+ try:
72
+ _ecg_model = Net1D(
73
+ in_channels=1,
74
+ base_filters=64,
75
+ ratio=1,
76
+ filter_list=[64, 160, 160, 400, 400, 1024, 1024],
77
+ m_blocks_list=[2, 2, 2, 3, 3, 4, 4],
78
+ kernel_size=16,
79
+ stride=2,
80
+ groups_width=16,
81
+ verbose=False,
82
+ use_bn=False,
83
+ use_do=False,
84
+ n_classes=1,
85
+ )
86
+ state_dict = torch.load(ecg_path, map_location=_device, weights_only=False)
87
+ _ecg_model.load_state_dict(state_dict)
88
+ _ecg_model.to(_device)
89
+ _ecg_model.eval()
90
+ print(f"[ML] ECGFounder loaded on {_device}")
91
+ except Exception as e:
92
+ print(f"[ML] Failed to load ECGFounder: {e}")
93
+ _ecg_model = None
94
+ else:
95
+ print(f"[ML] ECGFounder not found at {ecg_path}")
96
+
97
+ # Load XGBoost
98
+ if os.path.exists(xgb_path):
99
+ try:
100
+ _xgb_model = joblib.load(xgb_path)
101
+ print("[ML] XGBoost loaded")
102
+ except Exception as e:
103
+ print(f"[ML] Failed to load XGBoost: {e}")
104
+ _xgb_model = None
105
+ else:
106
+ print(f"[ML] XGBoost not found at {xgb_path}")
107
+
108
+ _models_loaded = True
109
+ return _ecg_model is not None or _xgb_model is not None
110
+
111
+
112
+ def predict(ecg_samples: list, sample_rate_hz: int = 100,
113
+ heart_rate_bpm: float = None, spo2_percent: float = None,
114
+ user_profile: dict = None, history_features: dict = None) -> dict:
115
+ """
116
+ Run ensemble prediction on ECG data.
117
+
118
+ Args:
119
+ ecg_samples: list of ECG ADC values from ESP32
120
+ sample_rate_hz: device sample rate (100Hz for ESP32)
121
+ heart_rate_bpm: HR from MAX30100
122
+ spo2_percent: SpO2 from MAX30100
123
+ user_profile: dict with age, sex, bmi, is_diabetic, etc.
124
+ history_features: dict with hr_baseline_24h, etc.
125
+
126
+ Returns:
127
+ dict with risk_score, risk_label, confidence, features, model_version
128
+ """
129
+ if not _models_loaded:
130
+ load_models()
131
+
132
+ ecg = np.array(ecg_samples, dtype=np.float32)
133
+ result = {
134
+ "risk_score": 0.0,
135
+ "risk_label": "unknown",
136
+ "confidence": 0.0,
137
+ "features": {},
138
+ "model_version": "v1.0-ensemble",
139
+ }
140
+
141
+ prob_ecg = None
142
+ prob_xgb = None
143
+
144
+ # --- ECGFounder Prediction ---
145
+ if _ecg_model is not None:
146
+ try:
147
+ # Upsample from device rate to 500Hz (5000 samples for 10s)
148
+ target_length = 5000
149
+ if len(ecg) != target_length:
150
+ ecg_500hz = resample(ecg, target_length)
151
+ else:
152
+ ecg_500hz = ecg.copy()
153
+
154
+ # Z-score normalize
155
+ mean = np.mean(ecg_500hz)
156
+ std = np.std(ecg_500hz) + 1e-8
157
+ ecg_500hz = (ecg_500hz - mean) / std
158
+ ecg_500hz = np.nan_to_num(ecg_500hz, nan=0.0)
159
+
160
+ # Shape: (1, 1, 5000)
161
+ tensor = torch.FloatTensor(ecg_500hz).reshape(1, 1, -1).to(_device)
162
+
163
+ with torch.no_grad():
164
+ logit = _ecg_model(tensor)
165
+ prob_ecg = float(torch.sigmoid(logit).cpu().item())
166
+
167
+ except Exception as e:
168
+ print(f"[ML] ECGFounder inference error: {e}")
169
+ prob_ecg = None
170
+
171
+ # --- XGBoost Prediction ---
172
+ if _xgb_model is not None:
173
+ try:
174
+ # Extract features at device sample rate
175
+ features = extract_ecg_features(
176
+ ecg, sample_rate=sample_rate_hz,
177
+ heart_rate_sensor=heart_rate_bpm,
178
+ spo2=spo2_percent,
179
+ )
180
+
181
+ # Add user profile features
182
+ if user_profile:
183
+ features["age"] = user_profile.get("age", 50)
184
+ features["sex"] = 1 if user_profile.get("sex") == "male" else 0
185
+ h = user_profile.get("height_cm", 170) / 100
186
+ w = user_profile.get("weight_kg", 70)
187
+ features["bmi"] = w / (h * h) if h > 0 else 25.0
188
+ features["is_diabetic"] = 1 if user_profile.get("is_diabetic") else 0
189
+ features["is_hypertensive"] = 1 if user_profile.get("is_hypertensive") else 0
190
+ features["is_smoker"] = 1 if user_profile.get("is_smoker") else 0
191
+ features["family_history"] = 1 if user_profile.get("family_history") else 0
192
+
193
+ # Add historical baseline features
194
+ if history_features:
195
+ for key in ["hr_baseline_24h", "hr_baseline_7d", "spo2_baseline_24h",
196
+ "hr_deviation", "spo2_deviation", "resting_hr_trend",
197
+ "readings_count_24h"]:
198
+ features[key] = history_features.get(key, 0.0)
199
+
200
+ # Convert to array (25 ECG features only for base model)
201
+ feat_array = features_to_array(features, include_profile=False).reshape(1, -1)
202
+ prob_xgb = float(_xgb_model.predict_proba(feat_array)[0, 1])
203
+
204
+ # Store features in result
205
+ result["features"] = {k: round(v, 4) for k, v in features.items()
206
+ if k in FEATURE_NAMES[:10]} # Top 10 for response size
207
+
208
+ except Exception as e:
209
+ print(f"[ML] XGBoost inference error: {e}")
210
+ prob_xgb = None
211
+
212
+ # --- Ensemble ---
213
+ if prob_ecg is not None and prob_xgb is not None:
214
+ risk_score = ECG_WEIGHT * prob_ecg + XGB_WEIGHT * prob_xgb
215
+ confidence = 1.0 - abs(prob_ecg - prob_xgb) # Higher when models agree
216
+ result["model_version"] = "v1.0-ensemble"
217
+ elif prob_ecg is not None:
218
+ risk_score = prob_ecg
219
+ confidence = 0.7
220
+ result["model_version"] = "v1.0-ecgfounder-only"
221
+ elif prob_xgb is not None:
222
+ risk_score = prob_xgb
223
+ confidence = 0.6
224
+ result["model_version"] = "v1.0-xgboost-only"
225
+ else:
226
+ return result # No models available
227
+
228
+ result["risk_score"] = round(float(risk_score), 4)
229
+ result["risk_label"] = get_risk_label(risk_score)
230
+ result["confidence"] = round(float(confidence), 4)
231
+
232
+ return result
ml_src/__init__.py ADDED
File without changes
ml_src/feature_extractor.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ECG Feature Extraction for XGBoost model.
3
+ Extracts 26 signal features from single-lead ECG using NeuroKit2.
4
+ This module is shared between ml/ training and backend/ inference.
5
+ """
6
+
7
+ import numpy as np
8
+ import neurokit2 as nk
9
+ from scipy.stats import kurtosis, skew
10
+
11
+
12
+ def extract_ecg_features(ecg_signal: np.ndarray, sample_rate: int = 100,
13
+ heart_rate_sensor: float = None,
14
+ spo2: float = None) -> dict:
15
+ """
16
+ Extract 26 features from single-lead ECG signal.
17
+
18
+ Args:
19
+ ecg_signal: 1D numpy array of ECG samples
20
+ sample_rate: Sampling rate in Hz (100 for ESP32, 500 for PTB-XL)
21
+ heart_rate_sensor: HR from MAX30100 (optional, for device features)
22
+ spo2: SpO2 from MAX30100 (optional, for device features)
23
+
24
+ Returns:
25
+ dict of 26 features (keys match XGBoost training feature names)
26
+ """
27
+ features = {}
28
+
29
+ try:
30
+ # Clean the ECG signal
31
+ ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sample_rate)
32
+
33
+ # Detect R-peaks
34
+ _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sample_rate)
35
+ r_peak_indices = rpeaks.get("ECG_R_Peaks", np.array([]))
36
+
37
+ if len(r_peak_indices) < 3:
38
+ return _fallback_features(ecg_signal, heart_rate_sensor, spo2)
39
+
40
+ # --- HRV Time-Domain Features (7) ---
41
+ rr_intervals = np.diff(r_peak_indices) / sample_rate * 1000 # ms
42
+
43
+ features["mean_rr"] = float(np.mean(rr_intervals))
44
+ features["sdnn"] = float(np.std(rr_intervals, ddof=1)) if len(rr_intervals) > 1 else 0.0
45
+ features["rmssd"] = float(np.sqrt(np.mean(np.diff(rr_intervals) ** 2))) if len(rr_intervals) > 1 else 0.0
46
+
47
+ nn_diff = np.abs(np.diff(rr_intervals))
48
+ features["pnn50"] = float(np.sum(nn_diff > 50) / len(nn_diff) * 100) if len(nn_diff) > 0 else 0.0
49
+
50
+ hr_from_rr = 60000.0 / rr_intervals
51
+ features["mean_hr_ecg"] = float(np.mean(hr_from_rr))
52
+ features["hr_std"] = float(np.std(hr_from_rr))
53
+ features["rr_range"] = float(np.max(rr_intervals) - np.min(rr_intervals))
54
+
55
+ # --- ECG Morphology Features (9) ---
56
+ try:
57
+ # Delineate ECG waves
58
+ _, waves = nk.ecg_delineate(ecg_cleaned, rpeaks, sampling_rate=sample_rate, method="dwt")
59
+
60
+ # QRS duration
61
+ qrs_onsets = [x for x in waves.get("ECG_Q_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
62
+ qrs_offsets = [x for x in waves.get("ECG_S_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
63
+ if qrs_onsets and qrs_offsets:
64
+ qrs_durations = []
65
+ for q, s in zip(qrs_onsets[:len(qrs_offsets)], qrs_offsets[:len(qrs_onsets)]):
66
+ qrs_durations.append(abs(s - q) / sample_rate * 1000)
67
+ features["qrs_duration"] = float(np.mean(qrs_durations)) if qrs_durations else 100.0
68
+ else:
69
+ features["qrs_duration"] = 100.0
70
+
71
+ # R amplitude
72
+ r_amplitudes = ecg_cleaned[r_peak_indices.astype(int)]
73
+ features["r_amplitude"] = float(np.mean(r_amplitudes))
74
+ features["r_amplitude_std"] = float(np.std(r_amplitudes))
75
+
76
+ # QT interval
77
+ t_offsets = [x for x in waves.get("ECG_T_Offsets", []) if isinstance(x, (int, float)) and not np.isnan(x)]
78
+ if qrs_onsets and t_offsets:
79
+ qt_intervals = []
80
+ for q, t in zip(qrs_onsets[:len(t_offsets)], t_offsets[:len(qrs_onsets)]):
81
+ qt_intervals.append(abs(t - q) / sample_rate * 1000)
82
+ features["qt_interval"] = float(np.mean(qt_intervals)) if qt_intervals else 400.0
83
+ # Bazett's QTc
84
+ mean_rr_sec = features["mean_rr"] / 1000
85
+ features["qtc"] = float(features["qt_interval"] / np.sqrt(mean_rr_sec)) if mean_rr_sec > 0 else 440.0
86
+ else:
87
+ features["qt_interval"] = 400.0
88
+ features["qtc"] = 440.0
89
+
90
+ # ST level (amplitude at J-point, ~40ms after R-peak)
91
+ j_offset = int(0.04 * sample_rate)
92
+ st_levels = []
93
+ for rp in r_peak_indices.astype(int):
94
+ j_idx = rp + j_offset
95
+ if j_idx < len(ecg_cleaned):
96
+ st_levels.append(ecg_cleaned[j_idx])
97
+ features["st_level"] = float(np.mean(st_levels)) if st_levels else 0.0
98
+
99
+ # T-wave amplitude
100
+ t_peaks = [x for x in waves.get("ECG_T_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
101
+ if t_peaks:
102
+ t_amps = [ecg_cleaned[int(t)] for t in t_peaks if int(t) < len(ecg_cleaned)]
103
+ features["t_amplitude"] = float(np.mean(t_amps)) if t_amps else 0.0
104
+ else:
105
+ features["t_amplitude"] = 0.0
106
+
107
+ # P-wave ratio (P amplitude / R amplitude)
108
+ p_peaks = [x for x in waves.get("ECG_P_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
109
+ if p_peaks and features["r_amplitude"] != 0:
110
+ p_amps = [ecg_cleaned[int(p)] for p in p_peaks if int(p) < len(ecg_cleaned)]
111
+ features["p_wave_ratio"] = float(np.mean(p_amps) / features["r_amplitude"]) if p_amps else 0.1
112
+ else:
113
+ features["p_wave_ratio"] = 0.1
114
+
115
+ except Exception:
116
+ features.setdefault("qrs_duration", 100.0)
117
+ features.setdefault("r_amplitude", float(np.max(ecg_cleaned) - np.min(ecg_cleaned)))
118
+ features.setdefault("r_amplitude_std", 0.0)
119
+ features.setdefault("qt_interval", 400.0)
120
+ features.setdefault("qtc", 440.0)
121
+ features.setdefault("st_level", 0.0)
122
+ features.setdefault("t_amplitude", 0.0)
123
+ features.setdefault("p_wave_ratio", 0.1)
124
+
125
+ # --- Signal Statistics (6) ---
126
+ features["rms"] = float(np.sqrt(np.mean(ecg_cleaned ** 2)))
127
+ features["entropy"] = float(_sample_entropy(ecg_cleaned))
128
+ features["zero_crossing_rate"] = float(
129
+ np.sum(np.diff(np.sign(ecg_cleaned - np.mean(ecg_cleaned))) != 0) / len(ecg_cleaned)
130
+ )
131
+ features["kurtosis"] = float(kurtosis(ecg_cleaned))
132
+ features["skewness"] = float(skew(ecg_cleaned))
133
+ features["snr"] = float(_estimate_snr(ecg_cleaned, sample_rate))
134
+
135
+ # --- Device Sensor Features (4) ---
136
+ features["heart_rate_sensor"] = float(heart_rate_sensor) if heart_rate_sensor else features["mean_hr_ecg"]
137
+ features["spo2"] = float(spo2) if spo2 else 97.0
138
+
139
+ hr_diff = abs(features["heart_rate_sensor"] - features["mean_hr_ecg"])
140
+ features["hr_sensor_ecg_diff"] = float(hr_diff)
141
+
142
+ # ECG quality score (based on peak regularity)
143
+ if len(rr_intervals) > 1:
144
+ cv = np.std(rr_intervals) / np.mean(rr_intervals)
145
+ features["ecg_quality"] = float(max(0, 1 - cv))
146
+ else:
147
+ features["ecg_quality"] = 0.5
148
+
149
+ except Exception:
150
+ return _fallback_features(ecg_signal, heart_rate_sensor, spo2)
151
+
152
+ return features
153
+
154
+
155
+ def _sample_entropy(signal, m=2, r_factor=0.2):
156
+ """Approximate sample entropy."""
157
+ try:
158
+ r = r_factor * np.std(signal)
159
+ N = len(signal)
160
+ if N < m + 2 or r == 0:
161
+ return 0.0
162
+
163
+ # Use simplified approach for speed
164
+ templates_m = np.array([signal[i:i + m] for i in range(N - m)])
165
+ templates_m1 = np.array([signal[i:i + m + 1] for i in range(N - m - 1)])
166
+
167
+ count_m = 0
168
+ count_m1 = 0
169
+
170
+ # Sample subset for speed
171
+ n_check = min(200, len(templates_m))
172
+ indices = np.random.choice(len(templates_m), n_check, replace=False) if len(templates_m) > n_check else range(len(templates_m))
173
+
174
+ for i in indices:
175
+ dist_m = np.max(np.abs(templates_m - templates_m[i]), axis=1)
176
+ count_m += np.sum(dist_m < r) - 1
177
+
178
+ if i < len(templates_m1):
179
+ dist_m1 = np.max(np.abs(templates_m1 - templates_m1[i]), axis=1)
180
+ count_m1 += np.sum(dist_m1 < r) - 1
181
+
182
+ if count_m == 0 or count_m1 == 0:
183
+ return 0.0
184
+
185
+ return -np.log(count_m1 / count_m)
186
+ except Exception:
187
+ return 0.0
188
+
189
+
190
+ def _estimate_snr(signal, sample_rate):
191
+ """Estimate signal-to-noise ratio."""
192
+ try:
193
+ cleaned = nk.ecg_clean(signal, sampling_rate=sample_rate)
194
+ noise = signal - cleaned
195
+ signal_power = np.mean(cleaned ** 2)
196
+ noise_power = np.mean(noise ** 2)
197
+ if noise_power == 0:
198
+ return 30.0
199
+ return float(10 * np.log10(signal_power / noise_power))
200
+ except Exception:
201
+ return 10.0
202
+
203
+
204
+ def _fallback_features(ecg_signal, heart_rate_sensor=None, spo2=None) -> dict:
205
+ """Return default features when ECG processing fails."""
206
+ return {
207
+ "mean_rr": 800.0, "sdnn": 50.0, "rmssd": 30.0, "pnn50": 10.0,
208
+ "mean_hr_ecg": 75.0, "hr_std": 5.0, "rr_range": 200.0,
209
+ "qrs_duration": 100.0, "r_amplitude": 1.0, "r_amplitude_std": 0.1,
210
+ "qt_interval": 400.0, "qtc": 440.0, "st_level": 0.0,
211
+ "t_amplitude": 0.3, "p_wave_ratio": 0.1,
212
+ "rms": float(np.sqrt(np.mean(ecg_signal ** 2))) if len(ecg_signal) > 0 else 0.5,
213
+ "entropy": 0.5, "zero_crossing_rate": 0.1,
214
+ "kurtosis": 0.0, "skewness": 0.0, "snr": 10.0,
215
+ "heart_rate_sensor": float(heart_rate_sensor) if heart_rate_sensor else 75.0,
216
+ "spo2": float(spo2) if spo2 else 97.0,
217
+ "hr_sensor_ecg_diff": 0.0, "ecg_quality": 0.5,
218
+ }
219
+
220
+
221
+ # Ordered feature names for XGBoost (must match training order)
222
+ FEATURE_NAMES = [
223
+ "mean_rr", "sdnn", "rmssd", "pnn50", "mean_hr_ecg", "hr_std", "rr_range",
224
+ "qrs_duration", "r_amplitude", "r_amplitude_std", "qt_interval", "qtc",
225
+ "st_level", "t_amplitude", "p_wave_ratio",
226
+ "rms", "entropy", "zero_crossing_rate", "kurtosis", "skewness", "snr",
227
+ "heart_rate_sensor", "spo2", "hr_sensor_ecg_diff", "ecg_quality",
228
+ ]
229
+
230
+ # User profile feature names (appended after ECG features)
231
+ PROFILE_FEATURE_NAMES = [
232
+ "age", "sex", "bmi", "is_diabetic", "is_hypertensive",
233
+ "is_smoker", "family_history",
234
+ ]
235
+
236
+ # Historical baseline feature names
237
+ HISTORY_FEATURE_NAMES = [
238
+ "hr_baseline_24h", "hr_baseline_7d", "spo2_baseline_24h",
239
+ "hr_deviation", "spo2_deviation", "resting_hr_trend", "readings_count_24h",
240
+ ]
241
+
242
+ ALL_FEATURE_NAMES = FEATURE_NAMES + PROFILE_FEATURE_NAMES + HISTORY_FEATURE_NAMES
243
+
244
+
245
+ def features_to_array(features: dict, include_profile: bool = False) -> np.ndarray:
246
+ """Convert feature dict to numpy array in correct order for XGBoost."""
247
+ names = ALL_FEATURE_NAMES if include_profile else FEATURE_NAMES
248
+ return np.array([features.get(name, 0.0) for name in names], dtype=np.float32)
ml_src/net1d.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Net1D: 1D CNN with Squeeze-and-Excitation for ECG classification.
3
+ From PKUDigitalHealth/ECGFounder (MIT License).
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class MyConv1dPadSame(nn.Module):
12
+ def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
13
+ super().__init__()
14
+ self.in_channels = in_channels
15
+ self.out_channels = out_channels
16
+ self.kernel_size = kernel_size
17
+ self.stride = stride
18
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, groups=groups)
19
+
20
+ def forward(self, x):
21
+ in_dim = x.shape[-1]
22
+ out_dim = (in_dim + self.stride - 1) // self.stride
23
+ p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
24
+ pad_left = p // 2
25
+ pad_right = p - pad_left
26
+ x = F.pad(x, (pad_left, pad_right), "constant", 0)
27
+ return self.conv(x)
28
+
29
+
30
+ class MyMaxPool1dPadSame(nn.Module):
31
+ def __init__(self, kernel_size):
32
+ super().__init__()
33
+ self.kernel_size = kernel_size
34
+ self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)
35
+
36
+ def forward(self, x):
37
+ p = max(0, self.kernel_size - 1)
38
+ pad_left = p // 2
39
+ pad_right = p - pad_left
40
+ x = F.pad(x, (pad_left, pad_right), "constant", 0)
41
+ return self.max_pool(x)
42
+
43
+
44
+ class Swish(nn.Module):
45
+ def forward(self, x):
46
+ return x * torch.sigmoid(x)
47
+
48
+
49
+ class BasicBlock(nn.Module):
50
+ def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
51
+ groups, downsample, is_first_block=False, use_bn=True, use_do=True):
52
+ super().__init__()
53
+ self.in_channels = in_channels
54
+ self.out_channels = out_channels
55
+ self.downsample = downsample
56
+ self.stride = stride if downsample else 1
57
+ self.is_first_block = is_first_block
58
+ self.use_bn = use_bn
59
+ self.use_do = use_do
60
+ middle = int(out_channels * ratio)
61
+
62
+ self.bn1 = nn.BatchNorm1d(in_channels)
63
+ self.activation1 = Swish()
64
+ self.do1 = nn.Dropout(p=0.5)
65
+ self.conv1 = MyConv1dPadSame(in_channels, middle, 1, 1, 1)
66
+
67
+ self.bn2 = nn.BatchNorm1d(middle)
68
+ self.activation2 = Swish()
69
+ self.do2 = nn.Dropout(p=0.5)
70
+ self.conv2 = MyConv1dPadSame(middle, middle, kernel_size, self.stride, groups)
71
+
72
+ self.bn3 = nn.BatchNorm1d(middle)
73
+ self.activation3 = Swish()
74
+ self.do3 = nn.Dropout(p=0.5)
75
+ self.conv3 = MyConv1dPadSame(middle, out_channels, 1, 1, 1)
76
+
77
+ # Squeeze-and-Excitation
78
+ r = 2
79
+ self.se_fc1 = nn.Linear(out_channels, out_channels // r)
80
+ self.se_fc2 = nn.Linear(out_channels // r, out_channels)
81
+ self.se_activation = Swish()
82
+
83
+ if self.downsample:
84
+ self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)
85
+
86
+ def forward(self, x):
87
+ identity = x
88
+ out = x
89
+
90
+ if not self.is_first_block:
91
+ if self.use_bn:
92
+ out = self.bn1(out)
93
+ out = self.activation1(out)
94
+ if self.use_do:
95
+ out = self.do1(out)
96
+ out = self.conv1(out)
97
+
98
+ if self.use_bn:
99
+ out = self.bn2(out)
100
+ out = self.activation2(out)
101
+ if self.use_do:
102
+ out = self.do2(out)
103
+ out = self.conv2(out)
104
+
105
+ if self.use_bn:
106
+ out = self.bn3(out)
107
+ out = self.activation3(out)
108
+ if self.use_do:
109
+ out = self.do3(out)
110
+ out = self.conv3(out)
111
+
112
+ # SE attention
113
+ se = out.mean(-1)
114
+ se = self.se_fc1(se)
115
+ se = self.se_activation(se)
116
+ se = self.se_fc2(se)
117
+ se = torch.sigmoid(se)
118
+ out = torch.einsum('abc,ab->abc', out, se)
119
+
120
+ if self.downsample:
121
+ identity = self.max_pool(identity)
122
+ if self.out_channels != self.in_channels:
123
+ identity = identity.transpose(-1, -2)
124
+ ch1 = (self.out_channels - self.in_channels) // 2
125
+ ch2 = self.out_channels - self.in_channels - ch1
126
+ identity = F.pad(identity, (ch1, ch2), "constant", 0)
127
+ identity = identity.transpose(-1, -2)
128
+
129
+ out += identity
130
+ return out
131
+
132
+
133
+ class BasicStage(nn.Module):
134
+ def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
135
+ groups, i_stage, m_blocks, use_bn=True, use_do=True):
136
+ super().__init__()
137
+ self.block_list = nn.ModuleList()
138
+ for i_block in range(m_blocks):
139
+ is_first = (i_stage == 0 and i_block == 0)
140
+ if i_block == 0:
141
+ tmp_block = BasicBlock(
142
+ in_channels, out_channels, ratio, kernel_size,
143
+ stride, groups, downsample=True,
144
+ is_first_block=is_first, use_bn=use_bn, use_do=use_do)
145
+ else:
146
+ tmp_block = BasicBlock(
147
+ out_channels, out_channels, ratio, kernel_size,
148
+ 1, groups, downsample=False,
149
+ is_first_block=False, use_bn=use_bn, use_do=use_do)
150
+ self.block_list.append(tmp_block)
151
+
152
+ def forward(self, x):
153
+ for block in self.block_list:
154
+ x = block(x)
155
+ return x
156
+
157
+
158
+ class Net1D(nn.Module):
159
+ """
160
+ 1D CNN for ECG classification.
161
+ Input: (batch, in_channels, length)
162
+ Output: (batch, n_classes)
163
+ """
164
+ def __init__(self, in_channels, base_filters, ratio, filter_list,
165
+ m_blocks_list, kernel_size, stride, groups_width,
166
+ n_classes, use_bn=True, use_do=True, verbose=False):
167
+ super().__init__()
168
+ self.n_stages = len(filter_list)
169
+ self.use_bn = use_bn
170
+
171
+ self.first_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size, stride=2)
172
+ self.first_bn = nn.BatchNorm1d(base_filters)
173
+ self.first_activation = Swish()
174
+
175
+ self.stage_list = nn.ModuleList()
176
+ in_ch = base_filters
177
+ for i_stage in range(self.n_stages):
178
+ out_ch = filter_list[i_stage]
179
+ self.stage_list.append(BasicStage(
180
+ in_ch, out_ch, ratio, kernel_size, stride,
181
+ out_ch // groups_width, i_stage, m_blocks_list[i_stage],
182
+ use_bn=use_bn, use_do=use_do))
183
+ in_ch = out_ch
184
+
185
+ self.dense = nn.Linear(in_ch, n_classes)
186
+
187
+ def forward(self, x):
188
+ out = self.first_conv(x)
189
+ if self.use_bn:
190
+ out = self.first_bn(out)
191
+ out = self.first_activation(out)
192
+
193
+ for stage in self.stage_list:
194
+ out = stage(out)
195
+
196
+ features = out.mean(-1) # Global Average Pooling
197
+ out = self.dense(features)
198
+ return out
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.6
2
+ uvicorn[standard]==0.34.0
3
+ motor==3.6.0
4
+ pymongo==4.10.1
5
+ pydantic==2.10.0
6
+ pydantic-settings==2.7.0
7
+ python-jose[cryptography]==3.3.0
8
+ bcrypt==4.2.0
9
+ python-multipart==0.0.18
10
+ numpy==1.26.4
11
+ scipy==1.14.1
12
+ httpx==0.28.1
13
+ xgboost>=2.0.0
14
+ neurokit2>=0.2.7
15
+ joblib>=1.3.0