# File: soci2/src/soci/persistence/database.py
# Author: RayMelius
# Last commit message: "Persist LLM probability across restarts; fix Gemini model fallback"
# Commit: 0e4c818
"""Database — SQLite persistence for simulation state."""
from __future__ import annotations
import hashlib
import json
import logging
import os
import secrets
from pathlib import Path
from typing import Optional
import aiosqlite
logger = logging.getLogger(__name__)
def _hash_password(password: str, salt: str) -> str:
return hashlib.sha256(f"{salt}{password}".encode()).hexdigest()
# SOCI_DATA_DIR env var lets you point at a persistent disk (e.g. /var/data on Render).
# An unset *or empty* SOCI_DATA_DIR falls back to ./data; without the `or`,
# an empty value would become Path("") == Path("."), i.e. the working directory.
DB_DIR = Path(os.environ.get("SOCI_DATA_DIR") or "data")
# Default database file used when Database() is constructed without a path.
DEFAULT_DB = DB_DIR / "soci.db"
# DDL executed by Database.connect() via executescript() on every connect.
# Every statement uses IF NOT EXISTS, so re-running it against an existing
# database file leaves existing tables and data in place.
#   snapshots     - full simulation state serialized as JSON (state_json)
#   event_log     - one row per logged simulation event; indexed on tick and agent_id
#   conversations - completed conversations; participants and turns stored as JSON
#   users         - login accounts: salted password hash, session token, linked agent
#   settings      - key/value store for configuration persisted across restarts
# NOTE(review): session tokens are written to users.token as-is (plaintext);
# anyone with a copy of the DB file can reuse live sessions. Consider storing
# only a hash of the token.
SCHEMA = """
CREATE TABLE IF NOT EXISTS snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
tick INTEGER NOT NULL,
day INTEGER NOT NULL,
state_json TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS event_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tick INTEGER NOT NULL,
day INTEGER NOT NULL,
time_str TEXT NOT NULL,
event_type TEXT NOT NULL,
agent_id TEXT,
location TEXT,
description TEXT NOT NULL,
metadata_json TEXT
);
CREATE TABLE IF NOT EXISTS conversations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
conv_id TEXT NOT NULL,
tick INTEGER NOT NULL,
day INTEGER NOT NULL,
location TEXT NOT NULL,
participants_json TEXT NOT NULL,
topic TEXT,
turns_json TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_event_tick ON event_log(tick);
CREATE INDEX IF NOT EXISTS idx_event_agent ON event_log(agent_id);
CREATE INDEX IF NOT EXISTS idx_conv_tick ON conversations(tick);
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
salt TEXT NOT NULL,
token TEXT UNIQUE,
agent_id TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
class Database:
    """Async SQLite database for simulation persistence.

    Wraps a single ``aiosqlite`` connection. Call :meth:`connect` before any
    other method and :meth:`close` when finished. Every write method commits
    before returning.
    """

    def __init__(self, db_path: str | Path = DEFAULT_DB) -> None:
        """Remember the database path; no connection is opened until connect()."""
        self.db_path = Path(db_path)
        self._db: Optional[aiosqlite.Connection] = None

    def _require_connection(self) -> aiosqlite.Connection:
        """Return the open connection, or raise if connect() has not been called.

        Used instead of ``assert self._db is not None`` so the check is not
        stripped when Python runs with ``-O``.

        Raises:
            RuntimeError: the database is not connected.
        """
        if self._db is None:
            raise RuntimeError("Database is not connected; call connect() first")
        return self._db

    async def connect(self) -> None:
        """Connect to the database file and create tables/indexes if missing.

        Creates the parent directory if needed. If already connected, the
        previous connection is closed first rather than leaked.
        """
        if self._db is not None:
            await self.close()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        db = await aiosqlite.connect(str(self.db_path))
        try:
            await db.executescript(SCHEMA)
            await db.commit()
        except Exception:
            # Don't leave a half-initialized connection open behind us.
            await db.close()
            raise
        self._db = db
        logger.info("Database connected: %s", self.db_path)

    async def close(self) -> None:
        """Close the connection if one is open; safe to call more than once."""
        # Detach first so the object reads as disconnected even if close() raises.
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    async def save_snapshot(self, name: str, tick: int, day: int, state: dict) -> int:
        """Save a full simulation state snapshot.

        Args:
            name: Label for the snapshot; several rows may share the same name.
            tick: Simulation tick at which the state was captured.
            day: Simulation day at which the state was captured.
            state: JSON-serializable simulation state.

        Returns:
            The row id of the inserted snapshot.
        """
        db = self._require_connection()
        cursor = await db.execute(
            "INSERT INTO snapshots (name, tick, day, state_json) VALUES (?, ?, ?, ?)",
            (name, tick, day, json.dumps(state)),
        )
        await db.commit()
        return cursor.lastrowid

    async def load_snapshot(self, name: Optional[str] = None) -> Optional[dict]:
        """Load the latest snapshot, or the latest one with the given name.

        "Latest" means the highest row id. A falsy ``name`` (None or "")
        selects the latest snapshot regardless of name.

        Returns:
            The decoded state, or None when no snapshot matches or the stored
            JSON cannot be decoded (a warning is logged in that case).
        """
        db = self._require_connection()
        if name:
            cursor = await db.execute(
                "SELECT state_json FROM snapshots WHERE name = ? ORDER BY id DESC LIMIT 1",
                (name,),
            )
        else:
            cursor = await db.execute(
                "SELECT state_json FROM snapshots ORDER BY id DESC LIMIT 1",
            )
        row = await cursor.fetchone()
        if row is None:
            return None
        try:
            return json.loads(row[0])
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            logger.warning("Corrupt snapshot in DB, ignoring: %s", exc)
            return None

    async def list_snapshots(self) -> list[dict]:
        """List metadata (no state) for all saved snapshots, newest first."""
        db = self._require_connection()
        cursor = await db.execute(
            "SELECT id, name, created_at, tick, day FROM snapshots ORDER BY id DESC"
        )
        rows = await cursor.fetchall()
        return [
            {"id": r[0], "name": r[1], "created_at": r[2], "tick": r[3], "day": r[4]}
            for r in rows
        ]

    async def log_event(
        self,
        tick: int,
        day: int,
        time_str: str,
        event_type: str,
        description: str,
        agent_id: str = "",
        location: str = "",
        metadata: Optional[dict] = None,
    ) -> None:
        """Append one simulation event to the event log and commit.

        ``metadata`` is stored as JSON; None or an empty dict is stored as NULL.
        """
        db = self._require_connection()
        await db.execute(
            "INSERT INTO event_log (tick, day, time_str, event_type, agent_id, location, description, metadata_json) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (tick, day, time_str, event_type, agent_id, location, description,
             json.dumps(metadata) if metadata else None),
        )
        await db.commit()

    async def save_conversation(self, conv_data: dict) -> None:
        """Save a completed conversation.

        Reads these keys from ``conv_data``: ``"id"``, ``"turns"`` (a list of
        dicts; the last one's ``"tick"`` is stored), ``"location"``,
        ``"participants"``, and optionally ``"topic"`` (default "") and
        ``"day"`` (default 0).
        """
        db = self._require_connection()
        turns = conv_data["turns"]
        await db.execute(
            "INSERT INTO conversations (conv_id, tick, day, location, participants_json, topic, turns_json) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                conv_data["id"],
                # Tick of the final turn; 0 for a conversation with no turns.
                turns[-1]["tick"] if turns else 0,
                # Day is not always supplied by callers yet; 0 keeps the old default.
                conv_data.get("day", 0),
                conv_data["location"],
                json.dumps(conv_data["participants"]),
                conv_data.get("topic", ""),
                json.dumps(turns),
            ),
        )
        await db.commit()

    async def get_recent_events(self, limit: int = 50) -> list[dict]:
        """Return up to ``limit`` most recently logged events, newest first."""
        db = self._require_connection()
        cursor = await db.execute(
            "SELECT tick, day, time_str, event_type, agent_id, location, description "
            "FROM event_log ORDER BY id DESC LIMIT ?",
            (limit,),
        )
        rows = await cursor.fetchall()
        return [
            {
                "tick": r[0], "day": r[1], "time_str": r[2],
                "event_type": r[3], "agent_id": r[4],
                "location": r[5], "description": r[6],
            }
            for r in rows
        ]

    # ── Auth / user methods ──────────────────────────────────────────────────

    async def create_user(self, username: str, password: str) -> dict:
        """Create a new user with a random salt and a fresh session token.

        Returns:
            ``{"username": ..., "token": ..., "agent_id": None}``.

        Raises:
            ValueError: the username is already taken.
        """
        db = self._require_connection()
        salt = secrets.token_hex(16)
        pw_hash = _hash_password(password, salt)
        token = secrets.token_hex(32)
        try:
            await db.execute(
                "INSERT INTO users (username, password_hash, salt, token) VALUES (?, ?, ?, ?)",
                (username, pw_hash, salt, token),
            )
            await db.commit()
        except aiosqlite.IntegrityError as exc:
            # username is UNIQUE in the schema. token is UNIQUE too, but a
            # collision of two random 32-byte tokens is not a realistic cause.
            raise ValueError(f"Username '{username}' is already taken") from exc
        return {"username": username, "token": token, "agent_id": None}

    async def authenticate_user(self, username: str, password: str) -> Optional[dict]:
        """Verify credentials and return the user with a freshly issued token.

        The new token overwrites the user's stored token, so any previously
        issued token for this user stops matching get_user_by_token().

        Returns:
            ``{"username", "token", "agent_id"}`` on success, or None if the
            user does not exist or the password is wrong.
        """
        db = self._require_connection()
        cursor = await db.execute(
            "SELECT username, password_hash, salt, agent_id FROM users WHERE username = ?",
            (username,),
        )
        row = await cursor.fetchone()
        if row is None:
            return None
        stored_username, stored_hash, salt, agent_id = row
        candidate = _hash_password(password, salt)
        # Constant-time comparison: a plain `!=` can return early and leak,
        # via timing, how much of the hash matched.
        if not secrets.compare_digest(candidate.encode(), stored_hash.encode()):
            return None
        token = secrets.token_hex(32)
        await db.execute(
            "UPDATE users SET token = ? WHERE username = ?", (token, username)
        )
        await db.commit()
        return {"username": stored_username, "token": token, "agent_id": agent_id}

    async def get_user_by_token(self, token: str) -> Optional[dict]:
        """Look up a user by session token; return None if no user holds it."""
        db = self._require_connection()
        cursor = await db.execute(
            "SELECT username, agent_id FROM users WHERE token = ?", (token,)
        )
        row = await cursor.fetchone()
        if row is None:
            return None
        return {"username": row[0], "agent_id": row[1]}

    async def set_user_agent(self, username: str, agent_id: str) -> None:
        """Link a player agent to a user account (no-op if the user is unknown)."""
        db = self._require_connection()
        await db.execute(
            "UPDATE users SET agent_id = ? WHERE username = ?", (agent_id, username)
        )
        await db.commit()

    async def logout_user(self, token: str) -> None:
        """Invalidate a session token by clearing it from its user row."""
        db = self._require_connection()
        await db.execute("UPDATE users SET token = NULL WHERE token = ?", (token,))
        await db.commit()

    # ── Settings / persistent config ─────────────────────────────────────────

    async def get_setting(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Return the persisted value for ``key``, or ``default`` if unset."""
        db = self._require_connection()
        cursor = await db.execute("SELECT value FROM settings WHERE key = ?", (key,))
        row = await cursor.fetchone()
        return row[0] if row else default

    async def set_setting(self, key: str, value: str) -> None:
        """Insert or overwrite the persisted value for ``key``."""
        db = self._require_connection()
        await db.execute(
            "INSERT INTO settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            (key, value),
        )
        await db.commit()