Spaces:
Runtime error
Runtime error
Commit ·
11e9a40
1
Parent(s): a86336f
Deploy cardiac monitor FastAPI backend
Browse files- FastAPI with JWT auth, health profiles, vitals, ML predictions
- ECGFounder + XGBoost ensemble for cardiac risk assessment
- Dockerfile downloads models from GitHub LFS during build
- CPU-only PyTorch for optimized container size
- .dockerignore +9 -0
- Dockerfile +28 -0
- README.md +15 -4
- app/__init__.py +0 -0
- app/config.py +16 -0
- app/database.py +31 -0
- app/main.py +46 -0
- app/middleware/__init__.py +0 -0
- app/middleware/auth.py +54 -0
- app/models/__init__.py +0 -0
- app/models/device.py +16 -0
- app/models/prediction.py +15 -0
- app/models/user.py +41 -0
- app/models/vitals.py +32 -0
- app/routes/__init__.py +0 -0
- app/routes/auth.py +85 -0
- app/routes/devices.py +83 -0
- app/routes/health.py +17 -0
- app/routes/predictions.py +69 -0
- app/routes/vitals.py +197 -0
- app/services/__init__.py +0 -0
- app/services/feature_extractor.py +29 -0
- app/services/ml_service.py +232 -0
- ml_src/__init__.py +0 -0
- ml_src/feature_extractor.py +248 -0
- ml_src/net1d.py +198 -0
- requirements.txt +15 -0
.dockerignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
| 3 |
+
.env
|
| 4 |
+
.env.*
|
| 5 |
+
.git
|
| 6 |
+
.gitignore
|
| 7 |
+
*.md
|
| 8 |
+
.venv
|
| 9 |
+
venv
|
Dockerfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install CPU-only PyTorch first (saves ~1.5GB vs full torch)
|
| 6 |
+
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 7 |
+
|
| 8 |
+
# Install remaining dependencies
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
# Copy application code
|
| 13 |
+
COPY app/ app/
|
| 14 |
+
COPY ml_src/ ml_src/
|
| 15 |
+
|
| 16 |
+
# Download ML models from GitHub LFS (avoids LFS issues on HF Spaces)
|
| 17 |
+
RUN mkdir -p ml_models && \
|
| 18 |
+
apt-get update && apt-get install -y --no-install-recommends curl && \
|
| 19 |
+
curl -L -o ml_models/ecgfounder_best.pt \
|
| 20 |
+
"https://media.githubusercontent.com/media/Sanuka23/cardiac-monitor/main/backend/ml_models/ecgfounder_best.pt" && \
|
| 21 |
+
curl -L -o ml_models/xgboost_cardiac.joblib \
|
| 22 |
+
"https://media.githubusercontent.com/media/Sanuka23/cardiac-monitor/main/backend/ml_models/xgboost_cardiac.joblib" && \
|
| 23 |
+
apt-get purge -y curl && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* && \
|
| 24 |
+
echo "Model sizes:" && ls -lh ml_models/
|
| 25 |
+
|
| 26 |
+
EXPOSE 7860
|
| 27 |
+
|
| 28 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,12 +1,23 @@
|
|
| 1 |
---
|
| 2 |
title: Cardiac Monitor API
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: red
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
| 9 |
-
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Cardiac Monitor API
|
| 3 |
+
emoji: ❤️
|
| 4 |
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
| 9 |
+
app_port: 7860
|
| 10 |
+
short_description: Cardiac monitoring API with ML risk prediction
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# Cardiac Monitor API
|
| 14 |
+
|
| 15 |
+
FastAPI backend for ESP32 Heart Rate, SpO2 & ECG cardiac monitoring system with ML-based risk prediction.
|
| 16 |
+
|
| 17 |
+
## Endpoints
|
| 18 |
+
|
| 19 |
+
- `GET /api/v1/health` — Health check
|
| 20 |
+
- `POST /api/v1/auth/register` — Register
|
| 21 |
+
- `POST /api/v1/auth/login` — Login (JWT)
|
| 22 |
+
- `GET /api/v1/vitals/{device_id}` — Vitals history
|
| 23 |
+
- `GET /api/v1/predictions/{device_id}` — Risk predictions
|
app/__init__.py
ADDED
|
File without changes
|
app/config.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Settings(BaseSettings):
|
| 5 |
+
MONGODB_URI: str = "mongodb://localhost:27017"
|
| 6 |
+
DATABASE_NAME: str = "cardiac_monitor"
|
| 7 |
+
JWT_SECRET: str = "changeme"
|
| 8 |
+
JWT_ALGORITHM: str = "HS256"
|
| 9 |
+
JWT_EXPIRY_HOURS: int = 24
|
| 10 |
+
API_KEY: str = "dev-api-key"
|
| 11 |
+
|
| 12 |
+
class Config:
|
| 13 |
+
env_file = ".env"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
settings = Settings()
|
app/database.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from motor.motor_asyncio import AsyncIOMotorClient
|
| 2 |
+
|
| 3 |
+
client: AsyncIOMotorClient = None
|
| 4 |
+
db = None
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
async def connect_db(uri: str, db_name: str):
|
| 8 |
+
global client, db
|
| 9 |
+
client = AsyncIOMotorClient(uri, serverSelectionTimeoutMS=5000)
|
| 10 |
+
db = client[db_name]
|
| 11 |
+
|
| 12 |
+
# Create indexes (non-blocking — if DB is unreachable, server still starts)
|
| 13 |
+
try:
|
| 14 |
+
await db.users.create_index("email", unique=True)
|
| 15 |
+
await db.devices.create_index("device_id", unique=True)
|
| 16 |
+
await db.vitals.create_index([("device_id", 1), ("timestamp", -1)])
|
| 17 |
+
await db.predictions.create_index([("device_id", 1), ("created_at", -1)])
|
| 18 |
+
print("[DB] Connected and indexes created.")
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"[DB] Warning: Could not create indexes: {e}")
|
| 21 |
+
print("[DB] Server will start but DB operations may fail until MongoDB is available.")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
async def close_db():
|
| 25 |
+
global client
|
| 26 |
+
if client:
|
| 27 |
+
client.close()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_db():
|
| 31 |
+
return db
|
app/main.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
|
| 6 |
+
from app.config import settings
|
| 7 |
+
from app.database import connect_db, close_db
|
| 8 |
+
from app.routes import vitals, auth, devices, health, predictions
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@asynccontextmanager
|
| 12 |
+
async def lifespan(app: FastAPI):
|
| 13 |
+
await connect_db(settings.MONGODB_URI, settings.DATABASE_NAME)
|
| 14 |
+
|
| 15 |
+
# Load ML models (non-blocking — server starts even if models aren't ready)
|
| 16 |
+
try:
|
| 17 |
+
from app.services.ml_service import load_models
|
| 18 |
+
load_models()
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"[ML] Could not load models: {e}")
|
| 21 |
+
print("[ML] Server will run without predictions until models are available.")
|
| 22 |
+
|
| 23 |
+
yield
|
| 24 |
+
await close_db()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
app = FastAPI(
|
| 28 |
+
title="Cardiac Monitor API",
|
| 29 |
+
version="1.0.0",
|
| 30 |
+
description="Backend for ESP32 Heart Rate, SpO2 & ECG Monitor",
|
| 31 |
+
lifespan=lifespan,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
app.add_middleware(
|
| 35 |
+
CORSMiddleware,
|
| 36 |
+
allow_origins=["*"],
|
| 37 |
+
allow_credentials=True,
|
| 38 |
+
allow_methods=["*"],
|
| 39 |
+
allow_headers=["*"],
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
app.include_router(health.router, prefix="/api/v1", tags=["health"])
|
| 43 |
+
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
|
| 44 |
+
app.include_router(vitals.router, prefix="/api/v1/vitals", tags=["vitals"])
|
| 45 |
+
app.include_router(predictions.router, prefix="/api/v1/predictions", tags=["predictions"])
|
| 46 |
+
app.include_router(devices.router, prefix="/api/v1/devices", tags=["devices"])
|
app/middleware/__init__.py
ADDED
|
File without changes
|
app/middleware/auth.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
|
| 3 |
+
import bcrypt
|
| 4 |
+
from bson import ObjectId
|
| 5 |
+
from fastapi import Depends, HTTPException, Header
|
| 6 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 7 |
+
from jose import jwt, JWTError
|
| 8 |
+
|
| 9 |
+
from app.config import settings
|
| 10 |
+
from app.database import get_db
|
| 11 |
+
|
| 12 |
+
security = HTTPBearer()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def hash_password(password: str) -> str:
|
| 16 |
+
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def verify_password(plain: str, hashed: str) -> bool:
|
| 20 |
+
return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def create_access_token(user_id: str) -> str:
|
| 24 |
+
expire = datetime.utcnow() + timedelta(hours=settings.JWT_EXPIRY_HOURS)
|
| 25 |
+
payload = {"sub": user_id, "exp": expire}
|
| 26 |
+
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
async def verify_api_key(x_api_key: str = Header(...)):
|
| 30 |
+
if x_api_key != settings.API_KEY:
|
| 31 |
+
raise HTTPException(status_code=401, detail="Invalid API key")
|
| 32 |
+
return x_api_key
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
async def get_current_user(
|
| 36 |
+
credentials: HTTPAuthorizationCredentials = Depends(security),
|
| 37 |
+
):
|
| 38 |
+
try:
|
| 39 |
+
payload = jwt.decode(
|
| 40 |
+
credentials.credentials,
|
| 41 |
+
settings.JWT_SECRET,
|
| 42 |
+
algorithms=[settings.JWT_ALGORITHM],
|
| 43 |
+
)
|
| 44 |
+
user_id = payload.get("sub")
|
| 45 |
+
if user_id is None:
|
| 46 |
+
raise HTTPException(status_code=401, detail="Invalid token")
|
| 47 |
+
except JWTError:
|
| 48 |
+
raise HTTPException(status_code=401, detail="Invalid token")
|
| 49 |
+
|
| 50 |
+
db = get_db()
|
| 51 |
+
user = await db.users.find_one({"_id": ObjectId(user_id)})
|
| 52 |
+
if user is None:
|
| 53 |
+
raise HTTPException(status_code=401, detail="User not found")
|
| 54 |
+
return user
|
app/models/__init__.py
ADDED
|
File without changes
|
app/models/device.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DeviceRegister(BaseModel):
|
| 7 |
+
device_id: str = Field(..., min_length=1, max_length=50)
|
| 8 |
+
name: Optional[str] = Field(None, max_length=100)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DeviceResponse(BaseModel):
|
| 12 |
+
device_id: str
|
| 13 |
+
name: Optional[str] = None
|
| 14 |
+
owner_user_id: Optional[str] = None
|
| 15 |
+
last_seen: Optional[datetime] = None
|
| 16 |
+
registered_at: datetime
|
app/models/prediction.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PredictionResponse(BaseModel):
|
| 7 |
+
id: str
|
| 8 |
+
vitals_id: str
|
| 9 |
+
device_id: str
|
| 10 |
+
risk_score: float = Field(..., ge=0.0, le=1.0)
|
| 11 |
+
risk_label: str
|
| 12 |
+
confidence: float = Field(..., ge=0.0, le=1.0)
|
| 13 |
+
features: dict
|
| 14 |
+
model_version: str
|
| 15 |
+
created_at: datetime
|
app/models/user.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, EmailStr, Field
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class HealthProfile(BaseModel):
|
| 7 |
+
age: Optional[int] = Field(None, ge=1, le=120)
|
| 8 |
+
sex: Optional[str] = Field(None, pattern="^(male|female|other)$")
|
| 9 |
+
height_cm: Optional[float] = Field(None, ge=50, le=300)
|
| 10 |
+
weight_kg: Optional[float] = Field(None, ge=10, le=500)
|
| 11 |
+
is_diabetic: bool = False
|
| 12 |
+
is_hypertensive: bool = False
|
| 13 |
+
is_smoker: bool = False
|
| 14 |
+
family_history: bool = False
|
| 15 |
+
known_conditions: List[str] = Field(default_factory=list)
|
| 16 |
+
medications: List[str] = Field(default_factory=list)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class UserRegister(BaseModel):
|
| 20 |
+
email: str = Field(..., min_length=5, max_length=100)
|
| 21 |
+
password: str = Field(..., min_length=6, max_length=100)
|
| 22 |
+
name: str = Field(..., min_length=1, max_length=100)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class UserLogin(BaseModel):
|
| 26 |
+
email: str
|
| 27 |
+
password: str
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class UserResponse(BaseModel):
|
| 31 |
+
id: str
|
| 32 |
+
email: str
|
| 33 |
+
name: str
|
| 34 |
+
device_ids: List[str] = Field(default_factory=list)
|
| 35 |
+
profile: Optional[HealthProfile] = None
|
| 36 |
+
created_at: datetime
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class TokenResponse(BaseModel):
|
| 40 |
+
access_token: str
|
| 41 |
+
token_type: str = "bearer"
|
app/models/vitals.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class VitalsCreate(BaseModel):
|
| 7 |
+
device_id: str = Field(..., min_length=1, max_length=50)
|
| 8 |
+
timestamp: int = Field(..., description="Unix epoch seconds from ESP32")
|
| 9 |
+
window_ms: int = Field(default=10000, ge=1000, le=60000)
|
| 10 |
+
sample_rate_hz: int = Field(default=100, ge=50, le=1000)
|
| 11 |
+
heart_rate_bpm: float = Field(..., ge=0, le=300)
|
| 12 |
+
spo2_percent: int = Field(..., ge=0, le=100)
|
| 13 |
+
ecg_lead_off: bool = Field(default=False)
|
| 14 |
+
ecg_samples: List[int] = Field(..., min_length=100, max_length=6000)
|
| 15 |
+
beat_timestamps_ms: List[int] = Field(default_factory=list)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class VitalsResponse(BaseModel):
|
| 19 |
+
id: str
|
| 20 |
+
device_id: str
|
| 21 |
+
timestamp: datetime
|
| 22 |
+
heart_rate_bpm: float
|
| 23 |
+
spo2_percent: int
|
| 24 |
+
ecg_lead_off: bool
|
| 25 |
+
sample_count: int
|
| 26 |
+
prediction: Optional[dict] = None
|
| 27 |
+
created_at: datetime
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class VitalsListResponse(BaseModel):
|
| 31 |
+
vitals: List[VitalsResponse]
|
| 32 |
+
total: int
|
app/routes/__init__.py
ADDED
|
File without changes
|
app/routes/auth.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 4 |
+
|
| 5 |
+
from app.database import get_db
|
| 6 |
+
from app.middleware.auth import (
|
| 7 |
+
hash_password,
|
| 8 |
+
verify_password,
|
| 9 |
+
create_access_token,
|
| 10 |
+
get_current_user,
|
| 11 |
+
)
|
| 12 |
+
from app.models.user import (
|
| 13 |
+
UserRegister,
|
| 14 |
+
UserLogin,
|
| 15 |
+
UserResponse,
|
| 16 |
+
TokenResponse,
|
| 17 |
+
HealthProfile,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
router = APIRouter()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@router.post("/register", response_model=TokenResponse)
|
| 24 |
+
async def register(data: UserRegister):
|
| 25 |
+
db = get_db()
|
| 26 |
+
|
| 27 |
+
existing = await db.users.find_one({"email": data.email})
|
| 28 |
+
if existing:
|
| 29 |
+
raise HTTPException(status_code=400, detail="Email already registered")
|
| 30 |
+
|
| 31 |
+
user_doc = {
|
| 32 |
+
"email": data.email,
|
| 33 |
+
"password_hash": hash_password(data.password),
|
| 34 |
+
"name": data.name,
|
| 35 |
+
"device_ids": [],
|
| 36 |
+
"profile": None,
|
| 37 |
+
"created_at": datetime.utcnow(),
|
| 38 |
+
}
|
| 39 |
+
result = await db.users.insert_one(user_doc)
|
| 40 |
+
token = create_access_token(str(result.inserted_id))
|
| 41 |
+
return TokenResponse(access_token=token)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@router.post("/login", response_model=TokenResponse)
|
| 45 |
+
async def login(data: UserLogin):
|
| 46 |
+
db = get_db()
|
| 47 |
+
|
| 48 |
+
user = await db.users.find_one({"email": data.email})
|
| 49 |
+
if not user or not verify_password(data.password, user["password_hash"]):
|
| 50 |
+
raise HTTPException(status_code=401, detail="Invalid email or password")
|
| 51 |
+
|
| 52 |
+
token = create_access_token(str(user["_id"]))
|
| 53 |
+
return TokenResponse(access_token=token)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@router.get("/me", response_model=UserResponse)
|
| 57 |
+
async def get_me(user=Depends(get_current_user)):
|
| 58 |
+
return UserResponse(
|
| 59 |
+
id=str(user["_id"]),
|
| 60 |
+
email=user["email"],
|
| 61 |
+
name=user["name"],
|
| 62 |
+
device_ids=user.get("device_ids", []),
|
| 63 |
+
profile=user.get("profile"),
|
| 64 |
+
created_at=user["created_at"],
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@router.put("/profile", response_model=UserResponse)
|
| 69 |
+
async def update_profile(profile: HealthProfile, user=Depends(get_current_user)):
|
| 70 |
+
db = get_db()
|
| 71 |
+
|
| 72 |
+
await db.users.update_one(
|
| 73 |
+
{"_id": user["_id"]},
|
| 74 |
+
{"$set": {"profile": profile.model_dump()}},
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
user["profile"] = profile.model_dump()
|
| 78 |
+
return UserResponse(
|
| 79 |
+
id=str(user["_id"]),
|
| 80 |
+
email=user["email"],
|
| 81 |
+
name=user["name"],
|
| 82 |
+
device_ids=user.get("device_ids", []),
|
| 83 |
+
profile=user["profile"],
|
| 84 |
+
created_at=user["created_at"],
|
| 85 |
+
)
|
app/routes/devices.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 4 |
+
|
| 5 |
+
from app.database import get_db
|
| 6 |
+
from app.middleware.auth import get_current_user
|
| 7 |
+
from app.models.device import DeviceRegister, DeviceResponse
|
| 8 |
+
|
| 9 |
+
router = APIRouter()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@router.post("/register", response_model=DeviceResponse)
|
| 13 |
+
async def register_device(data: DeviceRegister, user=Depends(get_current_user)):
|
| 14 |
+
db = get_db()
|
| 15 |
+
|
| 16 |
+
existing = await db.devices.find_one({"device_id": data.device_id})
|
| 17 |
+
if existing:
|
| 18 |
+
# If device already exists, link to this user if not already linked
|
| 19 |
+
if existing.get("owner_user_id") and str(existing["owner_user_id"]) != str(
|
| 20 |
+
user["_id"]
|
| 21 |
+
):
|
| 22 |
+
raise HTTPException(
|
| 23 |
+
status_code=400, detail="Device already registered to another user"
|
| 24 |
+
)
|
| 25 |
+
await db.devices.update_one(
|
| 26 |
+
{"device_id": data.device_id},
|
| 27 |
+
{
|
| 28 |
+
"$set": {
|
| 29 |
+
"owner_user_id": str(user["_id"]),
|
| 30 |
+
"name": data.name or existing.get("name"),
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
)
|
| 34 |
+
existing["owner_user_id"] = str(user["_id"])
|
| 35 |
+
if data.name:
|
| 36 |
+
existing["name"] = data.name
|
| 37 |
+
doc = existing
|
| 38 |
+
else:
|
| 39 |
+
doc = {
|
| 40 |
+
"device_id": data.device_id,
|
| 41 |
+
"name": data.name,
|
| 42 |
+
"owner_user_id": str(user["_id"]),
|
| 43 |
+
"last_seen": None,
|
| 44 |
+
"registered_at": datetime.utcnow(),
|
| 45 |
+
}
|
| 46 |
+
await db.devices.insert_one(doc)
|
| 47 |
+
|
| 48 |
+
# Add device_id to user's device list
|
| 49 |
+
await db.users.update_one(
|
| 50 |
+
{"_id": user["_id"]},
|
| 51 |
+
{"$addToSet": {"device_ids": data.device_id}},
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return DeviceResponse(
|
| 55 |
+
device_id=doc["device_id"],
|
| 56 |
+
name=doc.get("name"),
|
| 57 |
+
owner_user_id=doc.get("owner_user_id"),
|
| 58 |
+
last_seen=doc.get("last_seen"),
|
| 59 |
+
registered_at=doc.get("registered_at", datetime.utcnow()),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@router.get("", response_model=list[DeviceResponse])
|
| 64 |
+
async def list_devices(user=Depends(get_current_user)):
|
| 65 |
+
db = get_db()
|
| 66 |
+
|
| 67 |
+
device_ids = user.get("device_ids", [])
|
| 68 |
+
if not device_ids:
|
| 69 |
+
return []
|
| 70 |
+
|
| 71 |
+
cursor = db.devices.find({"device_id": {"$in": device_ids}})
|
| 72 |
+
docs = await cursor.to_list(length=100)
|
| 73 |
+
|
| 74 |
+
return [
|
| 75 |
+
DeviceResponse(
|
| 76 |
+
device_id=doc["device_id"],
|
| 77 |
+
name=doc.get("name"),
|
| 78 |
+
owner_user_id=doc.get("owner_user_id"),
|
| 79 |
+
last_seen=doc.get("last_seen"),
|
| 80 |
+
registered_at=doc.get("registered_at", datetime.utcnow()),
|
| 81 |
+
)
|
| 82 |
+
for doc in docs
|
| 83 |
+
]
|
app/routes/health.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
|
| 3 |
+
from app.database import get_db
|
| 4 |
+
|
| 5 |
+
router = APIRouter()
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@router.get("/health")
|
| 9 |
+
async def health_check():
|
| 10 |
+
db = get_db()
|
| 11 |
+
try:
|
| 12 |
+
await db.command("ping")
|
| 13 |
+
db_status = "connected"
|
| 14 |
+
except Exception:
|
| 15 |
+
db_status = "disconnected"
|
| 16 |
+
|
| 17 |
+
return {"status": "ok", "db": db_status, "version": "1.0.0"}
|
app/routes/predictions.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 2 |
+
|
| 3 |
+
from app.database import get_db
|
| 4 |
+
from app.middleware.auth import get_current_user
|
| 5 |
+
from app.models.prediction import PredictionResponse
|
| 6 |
+
|
| 7 |
+
router = APIRouter()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@router.get("/{device_id}/latest", response_model=PredictionResponse)
|
| 11 |
+
async def get_latest_prediction(device_id: str, _=Depends(get_current_user)):
|
| 12 |
+
db = get_db()
|
| 13 |
+
|
| 14 |
+
doc = await db.predictions.find_one(
|
| 15 |
+
{"device_id": device_id},
|
| 16 |
+
sort=[("created_at", -1)],
|
| 17 |
+
)
|
| 18 |
+
if not doc:
|
| 19 |
+
raise HTTPException(
|
| 20 |
+
status_code=404, detail="No predictions found for this device"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
return PredictionResponse(
|
| 24 |
+
id=str(doc["_id"]),
|
| 25 |
+
vitals_id=doc["vitals_id"],
|
| 26 |
+
device_id=doc["device_id"],
|
| 27 |
+
risk_score=doc["risk_score"],
|
| 28 |
+
risk_label=doc["risk_label"],
|
| 29 |
+
confidence=doc["confidence"],
|
| 30 |
+
features=doc.get("features", {}),
|
| 31 |
+
model_version=doc.get("model_version", "none"),
|
| 32 |
+
created_at=doc["created_at"],
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@router.get("/{device_id}")
|
| 37 |
+
async def get_prediction_history(
|
| 38 |
+
device_id: str,
|
| 39 |
+
limit: int = Query(default=50, ge=1, le=500),
|
| 40 |
+
offset: int = Query(default=0, ge=0),
|
| 41 |
+
_=Depends(get_current_user),
|
| 42 |
+
):
|
| 43 |
+
db = get_db()
|
| 44 |
+
|
| 45 |
+
cursor = (
|
| 46 |
+
db.predictions.find({"device_id": device_id})
|
| 47 |
+
.sort("created_at", -1)
|
| 48 |
+
.skip(offset)
|
| 49 |
+
.limit(limit)
|
| 50 |
+
)
|
| 51 |
+
docs = await cursor.to_list(length=limit)
|
| 52 |
+
|
| 53 |
+
results = [
|
| 54 |
+
PredictionResponse(
|
| 55 |
+
id=str(doc["_id"]),
|
| 56 |
+
vitals_id=doc["vitals_id"],
|
| 57 |
+
device_id=doc["device_id"],
|
| 58 |
+
risk_score=doc["risk_score"],
|
| 59 |
+
risk_label=doc["risk_label"],
|
| 60 |
+
confidence=doc["confidence"],
|
| 61 |
+
features=doc.get("features", {}),
|
| 62 |
+
model_version=doc.get("model_version", "none"),
|
| 63 |
+
created_at=doc["created_at"],
|
| 64 |
+
)
|
| 65 |
+
for doc in docs
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
total = await db.predictions.count_documents({"device_id": device_id})
|
| 69 |
+
return {"predictions": results, "total": total}
|
app/routes/vitals.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
+
from bson import ObjectId
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 5 |
+
|
| 6 |
+
from app.database import get_db
|
| 7 |
+
from app.middleware.auth import verify_api_key, get_current_user
|
| 8 |
+
from app.models.vitals import VitalsCreate, VitalsResponse, VitalsListResponse
|
| 9 |
+
|
| 10 |
+
router = APIRouter()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _vitals_doc_to_response(doc: dict, prediction: dict = None) -> VitalsResponse:
|
| 14 |
+
return VitalsResponse(
|
| 15 |
+
id=str(doc["_id"]),
|
| 16 |
+
device_id=doc["device_id"],
|
| 17 |
+
timestamp=doc["timestamp"],
|
| 18 |
+
heart_rate_bpm=doc["heart_rate_bpm"],
|
| 19 |
+
spo2_percent=doc["spo2_percent"],
|
| 20 |
+
ecg_lead_off=doc["ecg_lead_off"],
|
| 21 |
+
sample_count=len(doc.get("ecg_samples", [])),
|
| 22 |
+
prediction=prediction,
|
| 23 |
+
created_at=doc["created_at"],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@router.post("", response_model=VitalsResponse)
|
| 28 |
+
async def upload_vitals(data: VitalsCreate, _=Depends(verify_api_key)):
|
| 29 |
+
db = get_db()
|
| 30 |
+
|
| 31 |
+
vitals_doc = {
|
| 32 |
+
"device_id": data.device_id,
|
| 33 |
+
"timestamp": datetime.utcfromtimestamp(data.timestamp),
|
| 34 |
+
"window_ms": data.window_ms,
|
| 35 |
+
"sample_rate_hz": data.sample_rate_hz,
|
| 36 |
+
"heart_rate_bpm": data.heart_rate_bpm,
|
| 37 |
+
"spo2_percent": data.spo2_percent,
|
| 38 |
+
"ecg_lead_off": data.ecg_lead_off,
|
| 39 |
+
"ecg_samples": data.ecg_samples,
|
| 40 |
+
"beat_timestamps_ms": data.beat_timestamps_ms,
|
| 41 |
+
"created_at": datetime.utcnow(),
|
| 42 |
+
}
|
| 43 |
+
result = await db.vitals.insert_one(vitals_doc)
|
| 44 |
+
vitals_doc["_id"] = result.inserted_id
|
| 45 |
+
|
| 46 |
+
# Update device last_seen
|
| 47 |
+
await db.devices.update_one(
|
| 48 |
+
{"device_id": data.device_id},
|
| 49 |
+
{"$set": {"last_seen": datetime.utcnow()}},
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# Run ML prediction if models are available
|
| 53 |
+
prediction = None
|
| 54 |
+
try:
|
| 55 |
+
from app.services.ml_service import predict, _models_loaded, load_models
|
| 56 |
+
|
| 57 |
+
if not _models_loaded:
|
| 58 |
+
load_models()
|
| 59 |
+
|
| 60 |
+
if not data.ecg_lead_off and len(data.ecg_samples) >= 100:
|
| 61 |
+
# Get user profile for personalized prediction
|
| 62 |
+
device_doc = await db.devices.find_one({"device_id": data.device_id})
|
| 63 |
+
user_profile = None
|
| 64 |
+
history_features = None
|
| 65 |
+
|
| 66 |
+
if device_doc and device_doc.get("owner_user_id"):
|
| 67 |
+
user = await db.users.find_one({"_id": ObjectId(device_doc["owner_user_id"])})
|
| 68 |
+
if user and user.get("profile"):
|
| 69 |
+
user_profile = user["profile"]
|
| 70 |
+
|
| 71 |
+
# Compute historical baselines
|
| 72 |
+
from datetime import timedelta
|
| 73 |
+
now = datetime.utcnow()
|
| 74 |
+
pipeline_24h = [
|
| 75 |
+
{"$match": {"device_id": data.device_id,
|
| 76 |
+
"created_at": {"$gte": now - timedelta(hours=24)}}},
|
| 77 |
+
{"$group": {
|
| 78 |
+
"_id": None,
|
| 79 |
+
"avg_hr": {"$avg": "$heart_rate_bpm"},
|
| 80 |
+
"std_hr": {"$stdDevPop": "$heart_rate_bpm"},
|
| 81 |
+
"avg_spo2": {"$avg": "$spo2_percent"},
|
| 82 |
+
"std_spo2": {"$stdDevPop": "$spo2_percent"},
|
| 83 |
+
"count": {"$sum": 1},
|
| 84 |
+
}},
|
| 85 |
+
]
|
| 86 |
+
pipeline_7d = [
|
| 87 |
+
{"$match": {"device_id": data.device_id,
|
| 88 |
+
"created_at": {"$gte": now - timedelta(days=7)}}},
|
| 89 |
+
{"$group": {
|
| 90 |
+
"_id": None,
|
| 91 |
+
"avg_hr": {"$avg": "$heart_rate_bpm"},
|
| 92 |
+
"avg_spo2": {"$avg": "$spo2_percent"},
|
| 93 |
+
}},
|
| 94 |
+
]
|
| 95 |
+
|
| 96 |
+
stats_24h = await db.vitals.aggregate(pipeline_24h).to_list(1)
|
| 97 |
+
stats_7d = await db.vitals.aggregate(pipeline_7d).to_list(1)
|
| 98 |
+
|
| 99 |
+
if stats_24h:
|
| 100 |
+
s = stats_24h[0]
|
| 101 |
+
hr_std = s.get("std_hr", 1) or 1
|
| 102 |
+
spo2_std = s.get("std_spo2", 1) or 1
|
| 103 |
+
history_features = {
|
| 104 |
+
"hr_baseline_24h": s.get("avg_hr", 0),
|
| 105 |
+
"spo2_baseline_24h": s.get("avg_spo2", 0),
|
| 106 |
+
"hr_deviation": abs(data.heart_rate_bpm - s.get("avg_hr", data.heart_rate_bpm)) / hr_std,
|
| 107 |
+
"spo2_deviation": abs(data.spo2_percent - s.get("avg_spo2", data.spo2_percent)) / spo2_std,
|
| 108 |
+
"readings_count_24h": s.get("count", 0),
|
| 109 |
+
}
|
| 110 |
+
if stats_7d:
|
| 111 |
+
if history_features is None:
|
| 112 |
+
history_features = {}
|
| 113 |
+
history_features["hr_baseline_7d"] = stats_7d[0].get("avg_hr", 0)
|
| 114 |
+
|
| 115 |
+
ml_result = predict(
|
| 116 |
+
ecg_samples=data.ecg_samples,
|
| 117 |
+
sample_rate_hz=data.sample_rate_hz,
|
| 118 |
+
heart_rate_bpm=data.heart_rate_bpm,
|
| 119 |
+
spo2_percent=data.spo2_percent,
|
| 120 |
+
user_profile=user_profile,
|
| 121 |
+
history_features=history_features,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if ml_result["risk_label"] != "unknown":
|
| 125 |
+
pred_doc = {
|
| 126 |
+
"vitals_id": str(result.inserted_id),
|
| 127 |
+
"device_id": data.device_id,
|
| 128 |
+
"risk_score": ml_result["risk_score"],
|
| 129 |
+
"risk_label": ml_result["risk_label"],
|
| 130 |
+
"confidence": ml_result["confidence"],
|
| 131 |
+
"features": ml_result["features"],
|
| 132 |
+
"model_version": ml_result["model_version"],
|
| 133 |
+
"created_at": datetime.utcnow(),
|
| 134 |
+
}
|
| 135 |
+
await db.predictions.insert_one(pred_doc)
|
| 136 |
+
prediction = {
|
| 137 |
+
"risk_score": ml_result["risk_score"],
|
| 138 |
+
"risk_label": ml_result["risk_label"],
|
| 139 |
+
"confidence": ml_result["confidence"],
|
| 140 |
+
}
|
| 141 |
+
except ImportError:
|
| 142 |
+
pass # ML dependencies not installed, skip prediction
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"[ML] Prediction error: {e}")
|
| 145 |
+
|
| 146 |
+
return _vitals_doc_to_response(vitals_doc, prediction)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@router.get("/{device_id}/latest", response_model=VitalsResponse)
|
| 150 |
+
async def get_latest_vitals(device_id: str, _=Depends(get_current_user)):
|
| 151 |
+
db = get_db()
|
| 152 |
+
|
| 153 |
+
doc = await db.vitals.find_one(
|
| 154 |
+
{"device_id": device_id},
|
| 155 |
+
sort=[("timestamp", -1)],
|
| 156 |
+
)
|
| 157 |
+
if not doc:
|
| 158 |
+
raise HTTPException(status_code=404, detail="No vitals found for this device")
|
| 159 |
+
|
| 160 |
+
# Attach latest prediction if exists
|
| 161 |
+
pred = await db.predictions.find_one(
|
| 162 |
+
{"vitals_id": str(doc["_id"])},
|
| 163 |
+
)
|
| 164 |
+
pred_dict = None
|
| 165 |
+
if pred:
|
| 166 |
+
pred_dict = {
|
| 167 |
+
"risk_score": pred["risk_score"],
|
| 168 |
+
"risk_label": pred["risk_label"],
|
| 169 |
+
"confidence": pred["confidence"],
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
return _vitals_doc_to_response(doc, pred_dict)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
@router.get("/{device_id}", response_model=VitalsListResponse)
|
| 176 |
+
async def get_vitals_history(
|
| 177 |
+
device_id: str,
|
| 178 |
+
limit: int = Query(default=50, ge=1, le=500),
|
| 179 |
+
offset: int = Query(default=0, ge=0),
|
| 180 |
+
_=Depends(get_current_user),
|
| 181 |
+
):
|
| 182 |
+
db = get_db()
|
| 183 |
+
|
| 184 |
+
total = await db.vitals.count_documents({"device_id": device_id})
|
| 185 |
+
cursor = (
|
| 186 |
+
db.vitals.find(
|
| 187 |
+
{"device_id": device_id},
|
| 188 |
+
{"ecg_samples": 0}, # Exclude raw samples from list view
|
| 189 |
+
)
|
| 190 |
+
.sort("timestamp", -1)
|
| 191 |
+
.skip(offset)
|
| 192 |
+
.limit(limit)
|
| 193 |
+
)
|
| 194 |
+
docs = await cursor.to_list(length=limit)
|
| 195 |
+
|
| 196 |
+
vitals_list = [_vitals_doc_to_response(doc) for doc in docs]
|
| 197 |
+
return VitalsListResponse(vitals=vitals_list, total=total)
|
app/services/__init__.py
ADDED
|
File without changes
|
app/services/feature_extractor.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ECG Feature Extraction for backend inference.
|
| 3 |
+
Thin wrapper that imports from ml/src/feature_extractor.py to stay in sync.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
# Import from the bundled ML source
|
| 10 |
+
ML_SRC = os.path.join(os.path.dirname(__file__), "..", "..", "ml_src")
|
| 11 |
+
sys.path.insert(0, ML_SRC)
|
| 12 |
+
|
| 13 |
+
from feature_extractor import (
|
| 14 |
+
extract_ecg_features,
|
| 15 |
+
features_to_array,
|
| 16 |
+
FEATURE_NAMES,
|
| 17 |
+
PROFILE_FEATURE_NAMES,
|
| 18 |
+
HISTORY_FEATURE_NAMES,
|
| 19 |
+
ALL_FEATURE_NAMES,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"extract_ecg_features",
|
| 24 |
+
"features_to_array",
|
| 25 |
+
"FEATURE_NAMES",
|
| 26 |
+
"PROFILE_FEATURE_NAMES",
|
| 27 |
+
"HISTORY_FEATURE_NAMES",
|
| 28 |
+
"ALL_FEATURE_NAMES",
|
| 29 |
+
]
|
app/services/ml_service.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ML inference service for cardiac risk prediction.
|
| 3 |
+
Loads ECGFounder + XGBoost ensemble and runs prediction on incoming vitals.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy.signal import resample
|
| 10 |
+
import torch
|
| 11 |
+
import joblib
|
| 12 |
+
|
| 13 |
+
# Import model architecture from bundled ml_src/
|
| 14 |
+
BACKEND_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
|
| 15 |
+
ML_SRC = os.path.join(BACKEND_ROOT, "ml_src")
|
| 16 |
+
sys.path.insert(0, ML_SRC)
|
| 17 |
+
|
| 18 |
+
from net1d import Net1D
|
| 19 |
+
from feature_extractor import extract_ecg_features, features_to_array, FEATURE_NAMES
|
| 20 |
+
|
| 21 |
+
# Model paths — bundled in backend/ml_models/
|
| 22 |
+
MODEL_DIR = os.path.join(BACKEND_ROOT, "ml_models")
|
| 23 |
+
|
| 24 |
+
# Ensemble weights
|
| 25 |
+
ECG_WEIGHT = 0.60
|
| 26 |
+
XGB_WEIGHT = 0.40
|
| 27 |
+
|
| 28 |
+
# Risk labels
|
| 29 |
+
RISK_LABELS = {
|
| 30 |
+
(0.0, 0.2): "normal",
|
| 31 |
+
(0.2, 0.4): "low",
|
| 32 |
+
(0.4, 0.6): "moderate",
|
| 33 |
+
(0.6, 0.8): "elevated",
|
| 34 |
+
(0.8, 1.01): "high",
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
# Global model instances
|
| 38 |
+
_ecg_model = None
|
| 39 |
+
_xgb_model = None
|
| 40 |
+
_device = None
|
| 41 |
+
_models_loaded = False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_risk_label(score: float) -> str:
|
| 45 |
+
for (low, high), label in RISK_LABELS.items():
|
| 46 |
+
if low <= score < high:
|
| 47 |
+
return label
|
| 48 |
+
return "high"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_models():
|
| 52 |
+
"""Load both models into memory. Called once at startup."""
|
| 53 |
+
global _ecg_model, _xgb_model, _device, _models_loaded
|
| 54 |
+
|
| 55 |
+
if _models_loaded:
|
| 56 |
+
return True
|
| 57 |
+
|
| 58 |
+
# Select device
|
| 59 |
+
if torch.cuda.is_available():
|
| 60 |
+
_device = torch.device("cuda:0")
|
| 61 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 62 |
+
_device = torch.device("mps")
|
| 63 |
+
else:
|
| 64 |
+
_device = torch.device("cpu")
|
| 65 |
+
|
| 66 |
+
ecg_path = os.path.join(MODEL_DIR, "ecgfounder_best.pt")
|
| 67 |
+
xgb_path = os.path.join(MODEL_DIR, "xgboost_cardiac.joblib")
|
| 68 |
+
|
| 69 |
+
# Load ECGFounder
|
| 70 |
+
if os.path.exists(ecg_path):
|
| 71 |
+
try:
|
| 72 |
+
_ecg_model = Net1D(
|
| 73 |
+
in_channels=1,
|
| 74 |
+
base_filters=64,
|
| 75 |
+
ratio=1,
|
| 76 |
+
filter_list=[64, 160, 160, 400, 400, 1024, 1024],
|
| 77 |
+
m_blocks_list=[2, 2, 2, 3, 3, 4, 4],
|
| 78 |
+
kernel_size=16,
|
| 79 |
+
stride=2,
|
| 80 |
+
groups_width=16,
|
| 81 |
+
verbose=False,
|
| 82 |
+
use_bn=False,
|
| 83 |
+
use_do=False,
|
| 84 |
+
n_classes=1,
|
| 85 |
+
)
|
| 86 |
+
state_dict = torch.load(ecg_path, map_location=_device, weights_only=False)
|
| 87 |
+
_ecg_model.load_state_dict(state_dict)
|
| 88 |
+
_ecg_model.to(_device)
|
| 89 |
+
_ecg_model.eval()
|
| 90 |
+
print(f"[ML] ECGFounder loaded on {_device}")
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f"[ML] Failed to load ECGFounder: {e}")
|
| 93 |
+
_ecg_model = None
|
| 94 |
+
else:
|
| 95 |
+
print(f"[ML] ECGFounder not found at {ecg_path}")
|
| 96 |
+
|
| 97 |
+
# Load XGBoost
|
| 98 |
+
if os.path.exists(xgb_path):
|
| 99 |
+
try:
|
| 100 |
+
_xgb_model = joblib.load(xgb_path)
|
| 101 |
+
print("[ML] XGBoost loaded")
|
| 102 |
+
except Exception as e:
|
| 103 |
+
print(f"[ML] Failed to load XGBoost: {e}")
|
| 104 |
+
_xgb_model = None
|
| 105 |
+
else:
|
| 106 |
+
print(f"[ML] XGBoost not found at {xgb_path}")
|
| 107 |
+
|
| 108 |
+
_models_loaded = True
|
| 109 |
+
return _ecg_model is not None or _xgb_model is not None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def predict(ecg_samples: list, sample_rate_hz: int = 100,
|
| 113 |
+
heart_rate_bpm: float = None, spo2_percent: float = None,
|
| 114 |
+
user_profile: dict = None, history_features: dict = None) -> dict:
|
| 115 |
+
"""
|
| 116 |
+
Run ensemble prediction on ECG data.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
ecg_samples: list of ECG ADC values from ESP32
|
| 120 |
+
sample_rate_hz: device sample rate (100Hz for ESP32)
|
| 121 |
+
heart_rate_bpm: HR from MAX30100
|
| 122 |
+
spo2_percent: SpO2 from MAX30100
|
| 123 |
+
user_profile: dict with age, sex, bmi, is_diabetic, etc.
|
| 124 |
+
history_features: dict with hr_baseline_24h, etc.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
dict with risk_score, risk_label, confidence, features, model_version
|
| 128 |
+
"""
|
| 129 |
+
if not _models_loaded:
|
| 130 |
+
load_models()
|
| 131 |
+
|
| 132 |
+
ecg = np.array(ecg_samples, dtype=np.float32)
|
| 133 |
+
result = {
|
| 134 |
+
"risk_score": 0.0,
|
| 135 |
+
"risk_label": "unknown",
|
| 136 |
+
"confidence": 0.0,
|
| 137 |
+
"features": {},
|
| 138 |
+
"model_version": "v1.0-ensemble",
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
prob_ecg = None
|
| 142 |
+
prob_xgb = None
|
| 143 |
+
|
| 144 |
+
# --- ECGFounder Prediction ---
|
| 145 |
+
if _ecg_model is not None:
|
| 146 |
+
try:
|
| 147 |
+
# Upsample from device rate to 500Hz (5000 samples for 10s)
|
| 148 |
+
target_length = 5000
|
| 149 |
+
if len(ecg) != target_length:
|
| 150 |
+
ecg_500hz = resample(ecg, target_length)
|
| 151 |
+
else:
|
| 152 |
+
ecg_500hz = ecg.copy()
|
| 153 |
+
|
| 154 |
+
# Z-score normalize
|
| 155 |
+
mean = np.mean(ecg_500hz)
|
| 156 |
+
std = np.std(ecg_500hz) + 1e-8
|
| 157 |
+
ecg_500hz = (ecg_500hz - mean) / std
|
| 158 |
+
ecg_500hz = np.nan_to_num(ecg_500hz, nan=0.0)
|
| 159 |
+
|
| 160 |
+
# Shape: (1, 1, 5000)
|
| 161 |
+
tensor = torch.FloatTensor(ecg_500hz).reshape(1, 1, -1).to(_device)
|
| 162 |
+
|
| 163 |
+
with torch.no_grad():
|
| 164 |
+
logit = _ecg_model(tensor)
|
| 165 |
+
prob_ecg = float(torch.sigmoid(logit).cpu().item())
|
| 166 |
+
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f"[ML] ECGFounder inference error: {e}")
|
| 169 |
+
prob_ecg = None
|
| 170 |
+
|
| 171 |
+
# --- XGBoost Prediction ---
|
| 172 |
+
if _xgb_model is not None:
|
| 173 |
+
try:
|
| 174 |
+
# Extract features at device sample rate
|
| 175 |
+
features = extract_ecg_features(
|
| 176 |
+
ecg, sample_rate=sample_rate_hz,
|
| 177 |
+
heart_rate_sensor=heart_rate_bpm,
|
| 178 |
+
spo2=spo2_percent,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Add user profile features
|
| 182 |
+
if user_profile:
|
| 183 |
+
features["age"] = user_profile.get("age", 50)
|
| 184 |
+
features["sex"] = 1 if user_profile.get("sex") == "male" else 0
|
| 185 |
+
h = user_profile.get("height_cm", 170) / 100
|
| 186 |
+
w = user_profile.get("weight_kg", 70)
|
| 187 |
+
features["bmi"] = w / (h * h) if h > 0 else 25.0
|
| 188 |
+
features["is_diabetic"] = 1 if user_profile.get("is_diabetic") else 0
|
| 189 |
+
features["is_hypertensive"] = 1 if user_profile.get("is_hypertensive") else 0
|
| 190 |
+
features["is_smoker"] = 1 if user_profile.get("is_smoker") else 0
|
| 191 |
+
features["family_history"] = 1 if user_profile.get("family_history") else 0
|
| 192 |
+
|
| 193 |
+
# Add historical baseline features
|
| 194 |
+
if history_features:
|
| 195 |
+
for key in ["hr_baseline_24h", "hr_baseline_7d", "spo2_baseline_24h",
|
| 196 |
+
"hr_deviation", "spo2_deviation", "resting_hr_trend",
|
| 197 |
+
"readings_count_24h"]:
|
| 198 |
+
features[key] = history_features.get(key, 0.0)
|
| 199 |
+
|
| 200 |
+
# Convert to array (25 ECG features only for base model)
|
| 201 |
+
feat_array = features_to_array(features, include_profile=False).reshape(1, -1)
|
| 202 |
+
prob_xgb = float(_xgb_model.predict_proba(feat_array)[0, 1])
|
| 203 |
+
|
| 204 |
+
# Store features in result
|
| 205 |
+
result["features"] = {k: round(v, 4) for k, v in features.items()
|
| 206 |
+
if k in FEATURE_NAMES[:10]} # Top 10 for response size
|
| 207 |
+
|
| 208 |
+
except Exception as e:
|
| 209 |
+
print(f"[ML] XGBoost inference error: {e}")
|
| 210 |
+
prob_xgb = None
|
| 211 |
+
|
| 212 |
+
# --- Ensemble ---
|
| 213 |
+
if prob_ecg is not None and prob_xgb is not None:
|
| 214 |
+
risk_score = ECG_WEIGHT * prob_ecg + XGB_WEIGHT * prob_xgb
|
| 215 |
+
confidence = 1.0 - abs(prob_ecg - prob_xgb) # Higher when models agree
|
| 216 |
+
result["model_version"] = "v1.0-ensemble"
|
| 217 |
+
elif prob_ecg is not None:
|
| 218 |
+
risk_score = prob_ecg
|
| 219 |
+
confidence = 0.7
|
| 220 |
+
result["model_version"] = "v1.0-ecgfounder-only"
|
| 221 |
+
elif prob_xgb is not None:
|
| 222 |
+
risk_score = prob_xgb
|
| 223 |
+
confidence = 0.6
|
| 224 |
+
result["model_version"] = "v1.0-xgboost-only"
|
| 225 |
+
else:
|
| 226 |
+
return result # No models available
|
| 227 |
+
|
| 228 |
+
result["risk_score"] = round(float(risk_score), 4)
|
| 229 |
+
result["risk_label"] = get_risk_label(risk_score)
|
| 230 |
+
result["confidence"] = round(float(confidence), 4)
|
| 231 |
+
|
| 232 |
+
return result
|
ml_src/__init__.py
ADDED
|
File without changes
|
ml_src/feature_extractor.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ECG Feature Extraction for XGBoost model.
|
| 3 |
+
Extracts 26 signal features from single-lead ECG using NeuroKit2.
|
| 4 |
+
This module is shared between ml/ training and backend/ inference.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import neurokit2 as nk
|
| 9 |
+
from scipy.stats import kurtosis, skew
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def extract_ecg_features(ecg_signal: np.ndarray, sample_rate: int = 100,
|
| 13 |
+
heart_rate_sensor: float = None,
|
| 14 |
+
spo2: float = None) -> dict:
|
| 15 |
+
"""
|
| 16 |
+
Extract 26 features from single-lead ECG signal.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
ecg_signal: 1D numpy array of ECG samples
|
| 20 |
+
sample_rate: Sampling rate in Hz (100 for ESP32, 500 for PTB-XL)
|
| 21 |
+
heart_rate_sensor: HR from MAX30100 (optional, for device features)
|
| 22 |
+
spo2: SpO2 from MAX30100 (optional, for device features)
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
dict of 26 features (keys match XGBoost training feature names)
|
| 26 |
+
"""
|
| 27 |
+
features = {}
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
# Clean the ECG signal
|
| 31 |
+
ecg_cleaned = nk.ecg_clean(ecg_signal, sampling_rate=sample_rate)
|
| 32 |
+
|
| 33 |
+
# Detect R-peaks
|
| 34 |
+
_, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sample_rate)
|
| 35 |
+
r_peak_indices = rpeaks.get("ECG_R_Peaks", np.array([]))
|
| 36 |
+
|
| 37 |
+
if len(r_peak_indices) < 3:
|
| 38 |
+
return _fallback_features(ecg_signal, heart_rate_sensor, spo2)
|
| 39 |
+
|
| 40 |
+
# --- HRV Time-Domain Features (7) ---
|
| 41 |
+
rr_intervals = np.diff(r_peak_indices) / sample_rate * 1000 # ms
|
| 42 |
+
|
| 43 |
+
features["mean_rr"] = float(np.mean(rr_intervals))
|
| 44 |
+
features["sdnn"] = float(np.std(rr_intervals, ddof=1)) if len(rr_intervals) > 1 else 0.0
|
| 45 |
+
features["rmssd"] = float(np.sqrt(np.mean(np.diff(rr_intervals) ** 2))) if len(rr_intervals) > 1 else 0.0
|
| 46 |
+
|
| 47 |
+
nn_diff = np.abs(np.diff(rr_intervals))
|
| 48 |
+
features["pnn50"] = float(np.sum(nn_diff > 50) / len(nn_diff) * 100) if len(nn_diff) > 0 else 0.0
|
| 49 |
+
|
| 50 |
+
hr_from_rr = 60000.0 / rr_intervals
|
| 51 |
+
features["mean_hr_ecg"] = float(np.mean(hr_from_rr))
|
| 52 |
+
features["hr_std"] = float(np.std(hr_from_rr))
|
| 53 |
+
features["rr_range"] = float(np.max(rr_intervals) - np.min(rr_intervals))
|
| 54 |
+
|
| 55 |
+
# --- ECG Morphology Features (9) ---
|
| 56 |
+
try:
|
| 57 |
+
# Delineate ECG waves
|
| 58 |
+
_, waves = nk.ecg_delineate(ecg_cleaned, rpeaks, sampling_rate=sample_rate, method="dwt")
|
| 59 |
+
|
| 60 |
+
# QRS duration
|
| 61 |
+
qrs_onsets = [x for x in waves.get("ECG_Q_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
|
| 62 |
+
qrs_offsets = [x for x in waves.get("ECG_S_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
|
| 63 |
+
if qrs_onsets and qrs_offsets:
|
| 64 |
+
qrs_durations = []
|
| 65 |
+
for q, s in zip(qrs_onsets[:len(qrs_offsets)], qrs_offsets[:len(qrs_onsets)]):
|
| 66 |
+
qrs_durations.append(abs(s - q) / sample_rate * 1000)
|
| 67 |
+
features["qrs_duration"] = float(np.mean(qrs_durations)) if qrs_durations else 100.0
|
| 68 |
+
else:
|
| 69 |
+
features["qrs_duration"] = 100.0
|
| 70 |
+
|
| 71 |
+
# R amplitude
|
| 72 |
+
r_amplitudes = ecg_cleaned[r_peak_indices.astype(int)]
|
| 73 |
+
features["r_amplitude"] = float(np.mean(r_amplitudes))
|
| 74 |
+
features["r_amplitude_std"] = float(np.std(r_amplitudes))
|
| 75 |
+
|
| 76 |
+
# QT interval
|
| 77 |
+
t_offsets = [x for x in waves.get("ECG_T_Offsets", []) if isinstance(x, (int, float)) and not np.isnan(x)]
|
| 78 |
+
if qrs_onsets and t_offsets:
|
| 79 |
+
qt_intervals = []
|
| 80 |
+
for q, t in zip(qrs_onsets[:len(t_offsets)], t_offsets[:len(qrs_onsets)]):
|
| 81 |
+
qt_intervals.append(abs(t - q) / sample_rate * 1000)
|
| 82 |
+
features["qt_interval"] = float(np.mean(qt_intervals)) if qt_intervals else 400.0
|
| 83 |
+
# Bazett's QTc
|
| 84 |
+
mean_rr_sec = features["mean_rr"] / 1000
|
| 85 |
+
features["qtc"] = float(features["qt_interval"] / np.sqrt(mean_rr_sec)) if mean_rr_sec > 0 else 440.0
|
| 86 |
+
else:
|
| 87 |
+
features["qt_interval"] = 400.0
|
| 88 |
+
features["qtc"] = 440.0
|
| 89 |
+
|
| 90 |
+
# ST level (amplitude at J-point, ~40ms after R-peak)
|
| 91 |
+
j_offset = int(0.04 * sample_rate)
|
| 92 |
+
st_levels = []
|
| 93 |
+
for rp in r_peak_indices.astype(int):
|
| 94 |
+
j_idx = rp + j_offset
|
| 95 |
+
if j_idx < len(ecg_cleaned):
|
| 96 |
+
st_levels.append(ecg_cleaned[j_idx])
|
| 97 |
+
features["st_level"] = float(np.mean(st_levels)) if st_levels else 0.0
|
| 98 |
+
|
| 99 |
+
# T-wave amplitude
|
| 100 |
+
t_peaks = [x for x in waves.get("ECG_T_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
|
| 101 |
+
if t_peaks:
|
| 102 |
+
t_amps = [ecg_cleaned[int(t)] for t in t_peaks if int(t) < len(ecg_cleaned)]
|
| 103 |
+
features["t_amplitude"] = float(np.mean(t_amps)) if t_amps else 0.0
|
| 104 |
+
else:
|
| 105 |
+
features["t_amplitude"] = 0.0
|
| 106 |
+
|
| 107 |
+
# P-wave ratio (P amplitude / R amplitude)
|
| 108 |
+
p_peaks = [x for x in waves.get("ECG_P_Peaks", []) if isinstance(x, (int, float)) and not np.isnan(x)]
|
| 109 |
+
if p_peaks and features["r_amplitude"] != 0:
|
| 110 |
+
p_amps = [ecg_cleaned[int(p)] for p in p_peaks if int(p) < len(ecg_cleaned)]
|
| 111 |
+
features["p_wave_ratio"] = float(np.mean(p_amps) / features["r_amplitude"]) if p_amps else 0.1
|
| 112 |
+
else:
|
| 113 |
+
features["p_wave_ratio"] = 0.1
|
| 114 |
+
|
| 115 |
+
except Exception:
|
| 116 |
+
features.setdefault("qrs_duration", 100.0)
|
| 117 |
+
features.setdefault("r_amplitude", float(np.max(ecg_cleaned) - np.min(ecg_cleaned)))
|
| 118 |
+
features.setdefault("r_amplitude_std", 0.0)
|
| 119 |
+
features.setdefault("qt_interval", 400.0)
|
| 120 |
+
features.setdefault("qtc", 440.0)
|
| 121 |
+
features.setdefault("st_level", 0.0)
|
| 122 |
+
features.setdefault("t_amplitude", 0.0)
|
| 123 |
+
features.setdefault("p_wave_ratio", 0.1)
|
| 124 |
+
|
| 125 |
+
# --- Signal Statistics (6) ---
|
| 126 |
+
features["rms"] = float(np.sqrt(np.mean(ecg_cleaned ** 2)))
|
| 127 |
+
features["entropy"] = float(_sample_entropy(ecg_cleaned))
|
| 128 |
+
features["zero_crossing_rate"] = float(
|
| 129 |
+
np.sum(np.diff(np.sign(ecg_cleaned - np.mean(ecg_cleaned))) != 0) / len(ecg_cleaned)
|
| 130 |
+
)
|
| 131 |
+
features["kurtosis"] = float(kurtosis(ecg_cleaned))
|
| 132 |
+
features["skewness"] = float(skew(ecg_cleaned))
|
| 133 |
+
features["snr"] = float(_estimate_snr(ecg_cleaned, sample_rate))
|
| 134 |
+
|
| 135 |
+
# --- Device Sensor Features (4) ---
|
| 136 |
+
features["heart_rate_sensor"] = float(heart_rate_sensor) if heart_rate_sensor else features["mean_hr_ecg"]
|
| 137 |
+
features["spo2"] = float(spo2) if spo2 else 97.0
|
| 138 |
+
|
| 139 |
+
hr_diff = abs(features["heart_rate_sensor"] - features["mean_hr_ecg"])
|
| 140 |
+
features["hr_sensor_ecg_diff"] = float(hr_diff)
|
| 141 |
+
|
| 142 |
+
# ECG quality score (based on peak regularity)
|
| 143 |
+
if len(rr_intervals) > 1:
|
| 144 |
+
cv = np.std(rr_intervals) / np.mean(rr_intervals)
|
| 145 |
+
features["ecg_quality"] = float(max(0, 1 - cv))
|
| 146 |
+
else:
|
| 147 |
+
features["ecg_quality"] = 0.5
|
| 148 |
+
|
| 149 |
+
except Exception:
|
| 150 |
+
return _fallback_features(ecg_signal, heart_rate_sensor, spo2)
|
| 151 |
+
|
| 152 |
+
return features
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _sample_entropy(signal, m=2, r_factor=0.2):
|
| 156 |
+
"""Approximate sample entropy."""
|
| 157 |
+
try:
|
| 158 |
+
r = r_factor * np.std(signal)
|
| 159 |
+
N = len(signal)
|
| 160 |
+
if N < m + 2 or r == 0:
|
| 161 |
+
return 0.0
|
| 162 |
+
|
| 163 |
+
# Use simplified approach for speed
|
| 164 |
+
templates_m = np.array([signal[i:i + m] for i in range(N - m)])
|
| 165 |
+
templates_m1 = np.array([signal[i:i + m + 1] for i in range(N - m - 1)])
|
| 166 |
+
|
| 167 |
+
count_m = 0
|
| 168 |
+
count_m1 = 0
|
| 169 |
+
|
| 170 |
+
# Sample subset for speed
|
| 171 |
+
n_check = min(200, len(templates_m))
|
| 172 |
+
indices = np.random.choice(len(templates_m), n_check, replace=False) if len(templates_m) > n_check else range(len(templates_m))
|
| 173 |
+
|
| 174 |
+
for i in indices:
|
| 175 |
+
dist_m = np.max(np.abs(templates_m - templates_m[i]), axis=1)
|
| 176 |
+
count_m += np.sum(dist_m < r) - 1
|
| 177 |
+
|
| 178 |
+
if i < len(templates_m1):
|
| 179 |
+
dist_m1 = np.max(np.abs(templates_m1 - templates_m1[i]), axis=1)
|
| 180 |
+
count_m1 += np.sum(dist_m1 < r) - 1
|
| 181 |
+
|
| 182 |
+
if count_m == 0 or count_m1 == 0:
|
| 183 |
+
return 0.0
|
| 184 |
+
|
| 185 |
+
return -np.log(count_m1 / count_m)
|
| 186 |
+
except Exception:
|
| 187 |
+
return 0.0
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _estimate_snr(signal, sample_rate):
|
| 191 |
+
"""Estimate signal-to-noise ratio."""
|
| 192 |
+
try:
|
| 193 |
+
cleaned = nk.ecg_clean(signal, sampling_rate=sample_rate)
|
| 194 |
+
noise = signal - cleaned
|
| 195 |
+
signal_power = np.mean(cleaned ** 2)
|
| 196 |
+
noise_power = np.mean(noise ** 2)
|
| 197 |
+
if noise_power == 0:
|
| 198 |
+
return 30.0
|
| 199 |
+
return float(10 * np.log10(signal_power / noise_power))
|
| 200 |
+
except Exception:
|
| 201 |
+
return 10.0
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _fallback_features(ecg_signal, heart_rate_sensor=None, spo2=None) -> dict:
|
| 205 |
+
"""Return default features when ECG processing fails."""
|
| 206 |
+
return {
|
| 207 |
+
"mean_rr": 800.0, "sdnn": 50.0, "rmssd": 30.0, "pnn50": 10.0,
|
| 208 |
+
"mean_hr_ecg": 75.0, "hr_std": 5.0, "rr_range": 200.0,
|
| 209 |
+
"qrs_duration": 100.0, "r_amplitude": 1.0, "r_amplitude_std": 0.1,
|
| 210 |
+
"qt_interval": 400.0, "qtc": 440.0, "st_level": 0.0,
|
| 211 |
+
"t_amplitude": 0.3, "p_wave_ratio": 0.1,
|
| 212 |
+
"rms": float(np.sqrt(np.mean(ecg_signal ** 2))) if len(ecg_signal) > 0 else 0.5,
|
| 213 |
+
"entropy": 0.5, "zero_crossing_rate": 0.1,
|
| 214 |
+
"kurtosis": 0.0, "skewness": 0.0, "snr": 10.0,
|
| 215 |
+
"heart_rate_sensor": float(heart_rate_sensor) if heart_rate_sensor else 75.0,
|
| 216 |
+
"spo2": float(spo2) if spo2 else 97.0,
|
| 217 |
+
"hr_sensor_ecg_diff": 0.0, "ecg_quality": 0.5,
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# Ordered feature names for XGBoost (must match training order)
|
| 222 |
+
FEATURE_NAMES = [
|
| 223 |
+
"mean_rr", "sdnn", "rmssd", "pnn50", "mean_hr_ecg", "hr_std", "rr_range",
|
| 224 |
+
"qrs_duration", "r_amplitude", "r_amplitude_std", "qt_interval", "qtc",
|
| 225 |
+
"st_level", "t_amplitude", "p_wave_ratio",
|
| 226 |
+
"rms", "entropy", "zero_crossing_rate", "kurtosis", "skewness", "snr",
|
| 227 |
+
"heart_rate_sensor", "spo2", "hr_sensor_ecg_diff", "ecg_quality",
|
| 228 |
+
]
|
| 229 |
+
|
| 230 |
+
# User profile feature names (appended after ECG features)
|
| 231 |
+
PROFILE_FEATURE_NAMES = [
|
| 232 |
+
"age", "sex", "bmi", "is_diabetic", "is_hypertensive",
|
| 233 |
+
"is_smoker", "family_history",
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
# Historical baseline feature names
|
| 237 |
+
HISTORY_FEATURE_NAMES = [
|
| 238 |
+
"hr_baseline_24h", "hr_baseline_7d", "spo2_baseline_24h",
|
| 239 |
+
"hr_deviation", "spo2_deviation", "resting_hr_trend", "readings_count_24h",
|
| 240 |
+
]
|
| 241 |
+
|
| 242 |
+
ALL_FEATURE_NAMES = FEATURE_NAMES + PROFILE_FEATURE_NAMES + HISTORY_FEATURE_NAMES
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def features_to_array(features: dict, include_profile: bool = False) -> np.ndarray:
|
| 246 |
+
"""Convert feature dict to numpy array in correct order for XGBoost."""
|
| 247 |
+
names = ALL_FEATURE_NAMES if include_profile else FEATURE_NAMES
|
| 248 |
+
return np.array([features.get(name, 0.0) for name in names], dtype=np.float32)
|
ml_src/net1d.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Net1D: 1D CNN with Squeeze-and-Excitation for ECG classification.
|
| 3 |
+
From PKUDigitalHealth/ECGFounder (MIT License).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MyConv1dPadSame(nn.Module):
|
| 12 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.in_channels = in_channels
|
| 15 |
+
self.out_channels = out_channels
|
| 16 |
+
self.kernel_size = kernel_size
|
| 17 |
+
self.stride = stride
|
| 18 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, groups=groups)
|
| 19 |
+
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
in_dim = x.shape[-1]
|
| 22 |
+
out_dim = (in_dim + self.stride - 1) // self.stride
|
| 23 |
+
p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
|
| 24 |
+
pad_left = p // 2
|
| 25 |
+
pad_right = p - pad_left
|
| 26 |
+
x = F.pad(x, (pad_left, pad_right), "constant", 0)
|
| 27 |
+
return self.conv(x)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MyMaxPool1dPadSame(nn.Module):
|
| 31 |
+
def __init__(self, kernel_size):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.kernel_size = kernel_size
|
| 34 |
+
self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
p = max(0, self.kernel_size - 1)
|
| 38 |
+
pad_left = p // 2
|
| 39 |
+
pad_right = p - pad_left
|
| 40 |
+
x = F.pad(x, (pad_left, pad_right), "constant", 0)
|
| 41 |
+
return self.max_pool(x)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Swish(nn.Module):
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
return x * torch.sigmoid(x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class BasicBlock(nn.Module):
|
| 50 |
+
def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
|
| 51 |
+
groups, downsample, is_first_block=False, use_bn=True, use_do=True):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.in_channels = in_channels
|
| 54 |
+
self.out_channels = out_channels
|
| 55 |
+
self.downsample = downsample
|
| 56 |
+
self.stride = stride if downsample else 1
|
| 57 |
+
self.is_first_block = is_first_block
|
| 58 |
+
self.use_bn = use_bn
|
| 59 |
+
self.use_do = use_do
|
| 60 |
+
middle = int(out_channels * ratio)
|
| 61 |
+
|
| 62 |
+
self.bn1 = nn.BatchNorm1d(in_channels)
|
| 63 |
+
self.activation1 = Swish()
|
| 64 |
+
self.do1 = nn.Dropout(p=0.5)
|
| 65 |
+
self.conv1 = MyConv1dPadSame(in_channels, middle, 1, 1, 1)
|
| 66 |
+
|
| 67 |
+
self.bn2 = nn.BatchNorm1d(middle)
|
| 68 |
+
self.activation2 = Swish()
|
| 69 |
+
self.do2 = nn.Dropout(p=0.5)
|
| 70 |
+
self.conv2 = MyConv1dPadSame(middle, middle, kernel_size, self.stride, groups)
|
| 71 |
+
|
| 72 |
+
self.bn3 = nn.BatchNorm1d(middle)
|
| 73 |
+
self.activation3 = Swish()
|
| 74 |
+
self.do3 = nn.Dropout(p=0.5)
|
| 75 |
+
self.conv3 = MyConv1dPadSame(middle, out_channels, 1, 1, 1)
|
| 76 |
+
|
| 77 |
+
# Squeeze-and-Excitation
|
| 78 |
+
r = 2
|
| 79 |
+
self.se_fc1 = nn.Linear(out_channels, out_channels // r)
|
| 80 |
+
self.se_fc2 = nn.Linear(out_channels // r, out_channels)
|
| 81 |
+
self.se_activation = Swish()
|
| 82 |
+
|
| 83 |
+
if self.downsample:
|
| 84 |
+
self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)
|
| 85 |
+
|
| 86 |
+
def forward(self, x):
|
| 87 |
+
identity = x
|
| 88 |
+
out = x
|
| 89 |
+
|
| 90 |
+
if not self.is_first_block:
|
| 91 |
+
if self.use_bn:
|
| 92 |
+
out = self.bn1(out)
|
| 93 |
+
out = self.activation1(out)
|
| 94 |
+
if self.use_do:
|
| 95 |
+
out = self.do1(out)
|
| 96 |
+
out = self.conv1(out)
|
| 97 |
+
|
| 98 |
+
if self.use_bn:
|
| 99 |
+
out = self.bn2(out)
|
| 100 |
+
out = self.activation2(out)
|
| 101 |
+
if self.use_do:
|
| 102 |
+
out = self.do2(out)
|
| 103 |
+
out = self.conv2(out)
|
| 104 |
+
|
| 105 |
+
if self.use_bn:
|
| 106 |
+
out = self.bn3(out)
|
| 107 |
+
out = self.activation3(out)
|
| 108 |
+
if self.use_do:
|
| 109 |
+
out = self.do3(out)
|
| 110 |
+
out = self.conv3(out)
|
| 111 |
+
|
| 112 |
+
# SE attention
|
| 113 |
+
se = out.mean(-1)
|
| 114 |
+
se = self.se_fc1(se)
|
| 115 |
+
se = self.se_activation(se)
|
| 116 |
+
se = self.se_fc2(se)
|
| 117 |
+
se = torch.sigmoid(se)
|
| 118 |
+
out = torch.einsum('abc,ab->abc', out, se)
|
| 119 |
+
|
| 120 |
+
if self.downsample:
|
| 121 |
+
identity = self.max_pool(identity)
|
| 122 |
+
if self.out_channels != self.in_channels:
|
| 123 |
+
identity = identity.transpose(-1, -2)
|
| 124 |
+
ch1 = (self.out_channels - self.in_channels) // 2
|
| 125 |
+
ch2 = self.out_channels - self.in_channels - ch1
|
| 126 |
+
identity = F.pad(identity, (ch1, ch2), "constant", 0)
|
| 127 |
+
identity = identity.transpose(-1, -2)
|
| 128 |
+
|
| 129 |
+
out += identity
|
| 130 |
+
return out
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class BasicStage(nn.Module):
|
| 134 |
+
def __init__(self, in_channels, out_channels, ratio, kernel_size, stride,
|
| 135 |
+
groups, i_stage, m_blocks, use_bn=True, use_do=True):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.block_list = nn.ModuleList()
|
| 138 |
+
for i_block in range(m_blocks):
|
| 139 |
+
is_first = (i_stage == 0 and i_block == 0)
|
| 140 |
+
if i_block == 0:
|
| 141 |
+
tmp_block = BasicBlock(
|
| 142 |
+
in_channels, out_channels, ratio, kernel_size,
|
| 143 |
+
stride, groups, downsample=True,
|
| 144 |
+
is_first_block=is_first, use_bn=use_bn, use_do=use_do)
|
| 145 |
+
else:
|
| 146 |
+
tmp_block = BasicBlock(
|
| 147 |
+
out_channels, out_channels, ratio, kernel_size,
|
| 148 |
+
1, groups, downsample=False,
|
| 149 |
+
is_first_block=False, use_bn=use_bn, use_do=use_do)
|
| 150 |
+
self.block_list.append(tmp_block)
|
| 151 |
+
|
| 152 |
+
def forward(self, x):
|
| 153 |
+
for block in self.block_list:
|
| 154 |
+
x = block(x)
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class Net1D(nn.Module):
|
| 159 |
+
"""
|
| 160 |
+
1D CNN for ECG classification.
|
| 161 |
+
Input: (batch, in_channels, length)
|
| 162 |
+
Output: (batch, n_classes)
|
| 163 |
+
"""
|
| 164 |
+
def __init__(self, in_channels, base_filters, ratio, filter_list,
|
| 165 |
+
m_blocks_list, kernel_size, stride, groups_width,
|
| 166 |
+
n_classes, use_bn=True, use_do=True, verbose=False):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.n_stages = len(filter_list)
|
| 169 |
+
self.use_bn = use_bn
|
| 170 |
+
|
| 171 |
+
self.first_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size, stride=2)
|
| 172 |
+
self.first_bn = nn.BatchNorm1d(base_filters)
|
| 173 |
+
self.first_activation = Swish()
|
| 174 |
+
|
| 175 |
+
self.stage_list = nn.ModuleList()
|
| 176 |
+
in_ch = base_filters
|
| 177 |
+
for i_stage in range(self.n_stages):
|
| 178 |
+
out_ch = filter_list[i_stage]
|
| 179 |
+
self.stage_list.append(BasicStage(
|
| 180 |
+
in_ch, out_ch, ratio, kernel_size, stride,
|
| 181 |
+
out_ch // groups_width, i_stage, m_blocks_list[i_stage],
|
| 182 |
+
use_bn=use_bn, use_do=use_do))
|
| 183 |
+
in_ch = out_ch
|
| 184 |
+
|
| 185 |
+
self.dense = nn.Linear(in_ch, n_classes)
|
| 186 |
+
|
| 187 |
+
def forward(self, x):
|
| 188 |
+
out = self.first_conv(x)
|
| 189 |
+
if self.use_bn:
|
| 190 |
+
out = self.first_bn(out)
|
| 191 |
+
out = self.first_activation(out)
|
| 192 |
+
|
| 193 |
+
for stage in self.stage_list:
|
| 194 |
+
out = stage(out)
|
| 195 |
+
|
| 196 |
+
features = out.mean(-1) # Global Average Pooling
|
| 197 |
+
out = self.dense(features)
|
| 198 |
+
return out
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.6
|
| 2 |
+
uvicorn[standard]==0.34.0
|
| 3 |
+
motor==3.6.0
|
| 4 |
+
pymongo==4.10.1
|
| 5 |
+
pydantic==2.10.0
|
| 6 |
+
pydantic-settings==2.7.0
|
| 7 |
+
python-jose[cryptography]==3.3.0
|
| 8 |
+
bcrypt==4.2.0
|
| 9 |
+
python-multipart==0.0.18
|
| 10 |
+
numpy==1.26.4
|
| 11 |
+
scipy==1.14.1
|
| 12 |
+
httpx==0.28.1
|
| 13 |
+
xgboost>=2.0.0
|
| 14 |
+
neurokit2>=0.2.7
|
| 15 |
+
joblib>=1.3.0
|