Spaces:
Running
Running
| """ | |
| database/db.py | |
| ============== | |
| SQLite persistence layer for user accounts and stress analysis sessions. | |
| Tables | |
| ------ | |
| - ``users`` — username, password_hash, encrypted_history, created_at | |
| - ``sessions`` — per-analysis snapshots linked to users | |
| Thread-safety is handled by using ``check_same_thread=False`` and | |
| relying on SQLite's internal serialisation (WAL mode). | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import sqlite3 | |
| import time | |
| from typing import Any, Optional | |
| logger = logging.getLogger(__name__) | |
| # Default database path (overridable via env var or constructor arg) | |
| _DEFAULT_DB_PATH = os.environ.get("STRESS_DB_PATH", "stress_detection.db") | |
| class DatabaseManager: | |
| """Thin wrapper around a SQLite database for user + session storage. | |
| Parameters | |
| ---------- | |
| db_path : str | |
| File path for the SQLite database. Use ``":memory:"`` for an | |
| ephemeral in-memory database (useful in tests). | |
| """ | |
| def __init__(self, db_path: str = _DEFAULT_DB_PATH) -> None: | |
| self._db_path = db_path | |
| self._conn = sqlite3.connect( | |
| db_path, check_same_thread=False, | |
| ) | |
| self._conn.row_factory = sqlite3.Row | |
| self._conn.execute("PRAGMA journal_mode=WAL") | |
| self._conn.execute("PRAGMA foreign_keys=ON") | |
| self._create_tables() | |
| # ------------------------------------------------------------------ | |
| # Schema | |
| # ------------------------------------------------------------------ | |
| def _create_tables(self) -> None: | |
| """Create tables if they do not exist.""" | |
| self._conn.executescript( | |
| """ | |
| CREATE TABLE IF NOT EXISTS users ( | |
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| username TEXT NOT NULL UNIQUE, | |
| password_hash TEXT NOT NULL, | |
| encrypted_history TEXT, | |
| created_at REAL NOT NULL | |
| ); | |
| CREATE TABLE IF NOT EXISTS sessions ( | |
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| user_id INTEGER NOT NULL, | |
| stress_score REAL NOT NULL, | |
| stress_label TEXT NOT NULL, | |
| temporal_data TEXT NOT NULL, | |
| interventions TEXT NOT NULL, | |
| is_crisis INTEGER NOT NULL DEFAULT 0, | |
| crisis_message TEXT, | |
| matched_triggers TEXT NOT NULL, | |
| attention_weights TEXT NOT NULL, | |
| created_at REAL NOT NULL, | |
| FOREIGN KEY (user_id) REFERENCES users(id) | |
| ); | |
| CREATE INDEX IF NOT EXISTS idx_sessions_user_id | |
| ON sessions(user_id); | |
| CREATE INDEX IF NOT EXISTS idx_sessions_created_at | |
| ON sessions(created_at); | |
| """ | |
| ) | |
| # ------------------------------------------------------------------ | |
| # User CRUD | |
| # ------------------------------------------------------------------ | |
| def create_user( | |
| self, | |
| username: str, | |
| password_hash: str, | |
| ) -> int: | |
| """Insert a new user and return their ``id``. | |
| Raises | |
| ------ | |
| sqlite3.IntegrityError | |
| If the username already exists. | |
| """ | |
| cur = self._conn.execute( | |
| "INSERT INTO users (username, password_hash, encrypted_history, created_at) " | |
| "VALUES (?, ?, NULL, ?)", | |
| (username, password_hash, time.time()), | |
| ) | |
| self._conn.commit() | |
| return cur.lastrowid # type: ignore[return-value] | |
| def get_user(self, username: str) -> Optional[dict[str, Any]]: | |
| """Return a user dict or ``None`` if not found.""" | |
| row = self._conn.execute( | |
| "SELECT id, username, password_hash, encrypted_history, created_at " | |
| "FROM users WHERE username = ?", | |
| (username,), | |
| ).fetchone() | |
| if row is None: | |
| return None | |
| return dict(row) | |
| def user_exists(self, username: str) -> bool: | |
| """Check whether a username is already taken.""" | |
| row = self._conn.execute( | |
| "SELECT 1 FROM users WHERE username = ?", (username,) | |
| ).fetchone() | |
| return row is not None | |
| def update_encrypted_history( | |
| self, username: str, encrypted_history: str | |
| ) -> None: | |
| """Persist the user's updated encrypted temporal history.""" | |
| self._conn.execute( | |
| "UPDATE users SET encrypted_history = ? WHERE username = ?", | |
| (encrypted_history, username), | |
| ) | |
| self._conn.commit() | |
| # ------------------------------------------------------------------ | |
| # Session CRUD | |
| # ------------------------------------------------------------------ | |
| def save_session( | |
| self, | |
| username: str, | |
| stress_score: float, | |
| stress_label: str, | |
| temporal_data: dict, | |
| interventions: list[dict], | |
| is_crisis: bool, | |
| crisis_message: Optional[str], | |
| matched_triggers: list[str], | |
| attention_weights: list[float], | |
| ) -> int: | |
| """Persist a single analysis session and return its ``id``.""" | |
| user = self.get_user(username) | |
| if user is None: | |
| raise ValueError(f"User '{username}' not found") | |
| cur = self._conn.execute( | |
| "INSERT INTO sessions " | |
| "(user_id, stress_score, stress_label, temporal_data, " | |
| " interventions, is_crisis, crisis_message, matched_triggers, " | |
| " attention_weights, created_at) " | |
| "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | |
| ( | |
| user["id"], | |
| stress_score, | |
| stress_label, | |
| json.dumps(temporal_data), | |
| json.dumps(interventions), | |
| int(is_crisis), | |
| crisis_message, | |
| json.dumps(matched_triggers), | |
| json.dumps(attention_weights), | |
| time.time(), | |
| ), | |
| ) | |
| self._conn.commit() | |
| return cur.lastrowid # type: ignore[return-value] | |
| def get_sessions( | |
| self, | |
| username: str, | |
| limit: int = 50, | |
| offset: int = 0, | |
| ) -> list[dict[str, Any]]: | |
| """Return past sessions for a user, newest first. | |
| Parameters | |
| ---------- | |
| username : str | |
| The user whose sessions to retrieve. | |
| limit : int | |
| Maximum number of sessions to return. | |
| offset : int | |
| Number of sessions to skip (for pagination). | |
| """ | |
| user = self.get_user(username) | |
| if user is None: | |
| return [] | |
| rows = self._conn.execute( | |
| "SELECT id, stress_score, stress_label, temporal_data, " | |
| "interventions, is_crisis, crisis_message, matched_triggers, " | |
| "attention_weights, created_at " | |
| "FROM sessions WHERE user_id = ? " | |
| "ORDER BY created_at DESC LIMIT ? OFFSET ?", | |
| (user["id"], limit, offset), | |
| ).fetchall() | |
| sessions = [] | |
| for row in rows: | |
| session = dict(row) | |
| session["temporal_data"] = json.loads(session["temporal_data"]) | |
| session["interventions"] = json.loads(session["interventions"]) | |
| session["is_crisis"] = bool(session["is_crisis"]) | |
| session["matched_triggers"] = json.loads(session["matched_triggers"]) | |
| session["attention_weights"] = json.loads(session["attention_weights"]) | |
| sessions.append(session) | |
| return sessions | |
| def get_session_count(self, username: str) -> int: | |
| """Return the total number of sessions for a user.""" | |
| user = self.get_user(username) | |
| if user is None: | |
| return 0 | |
| row = self._conn.execute( | |
| "SELECT COUNT(*) as cnt FROM sessions WHERE user_id = ?", | |
| (user["id"],), | |
| ).fetchone() | |
| return row["cnt"] | |
| # ------------------------------------------------------------------ | |
| # Lifecycle | |
| # ------------------------------------------------------------------ | |
| def close(self) -> None: | |
| """Close the database connection.""" | |
| self._conn.close() | |