|
|
""" |
|
|
Session Store Service |
|
|
|
|
|
Thread-safe session-scoped storage for user layers and context. |
|
|
Replaces global SESSION_LAYERS with per-session isolation. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import threading |
|
|
from datetime import datetime, timedelta |
|
|
from typing import Dict, List, Optional, Any |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)


class SessionStore:
    """
    Thread-safe session-scoped storage with TTL expiration.

    The class is a process-wide singleton: every ``SessionStore(...)`` call
    returns the same object, and only the first call's arguments take effect.

    Each session is a dict holding:
    - "layers": list of layer dicts created by the user (oldest first)
    - "created": datetime at which the session entry was created
    - "accessed": datetime of the most recent read/write through this class

    Sessions whose last access is older than the TTL (default 2 hours) are
    dropped when ``cleanup_expired()`` is called; nothing in this class
    schedules that call automatically.
    """

    _instance: Optional["SessionStore"] = None
    # Guards singleton creation and one-time initialisation so that two
    # threads racing on the first construction cannot each build their own
    # session table and lock.
    _instance_lock = threading.Lock()

    def __new__(cls, *args: Any, **kwargs: Any) -> "SessionStore":
        """
        Return the shared instance, creating it on first use.

        Constructor arguments are accepted (and ignored here) because Python
        forwards them to ``__new__`` as well as ``__init__``; without this,
        ``SessionStore(ttl_hours=1)`` would raise ``TypeError``.
        """
        with cls._instance_lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance.initialized = False
                cls._instance = instance
            return cls._instance

    def __init__(self, ttl_hours: int = 2, max_layers_per_session: int = 15):
        """
        Initialise the singleton on its first construction only.

        Args:
            ttl_hours: Idle time, in hours, after which a session is
                considered expired by ``cleanup_expired()``.
            max_layers_per_session: Maximum number of layers kept per
                session; older layers are dropped first.

        Subsequent constructions return immediately and leave the existing
        configuration and sessions untouched.
        """
        with self._instance_lock:
            if self.initialized:
                return

            self._sessions: Dict[str, dict] = {}
            self._lock = threading.Lock()
            self.ttl = timedelta(hours=ttl_hours)
            self.max_layers = max_layers_per_session
            self.initialized = True

        logger.info(
            "SessionStore initialized with TTL=%sh, max_layers=%s",
            ttl_hours,
            max_layers_per_session,
        )

    def _get_or_create_session(self, session_id: str) -> dict:
        """
        Get existing session or create new one.

        Callers are expected to hold ``self._lock``; this helper does not
        acquire it.
        """
        session = self._sessions.get(session_id)
        if session is None:
            # One clock read so "created" and "accessed" start out equal.
            now = datetime.now()
            session = {"layers": [], "created": now, "accessed": now}
            self._sessions[session_id] = session
        return session

    def get_layers(self, session_id: str) -> List[dict]:
        """
        Get all layers for a session.

        An unknown ``session_id`` gets a new, empty session as a side effect.
        The returned list is a shallow copy: the list itself is independent
        of the store, but the layer dicts inside it are the stored objects.
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            session["accessed"] = datetime.now()
            return session["layers"].copy()

    def add_layer(self, session_id: str, layer: dict) -> None:
        """
        Add a layer to a session, creating the session if needed.

        Enforces the max_layers limit by removing the oldest layers (those
        at the front of the list) until the limit is met.
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            layers = session["layers"]
            layers.append(layer)
            session["accessed"] = datetime.now()

            overflow = len(layers) - self.max_layers
            if overflow > 0:
                for removed in layers[:overflow]:
                    logger.debug(
                        "Session %s: removed oldest layer %s",
                        session_id[:8],
                        removed.get("name", "unknown"),
                    )
                # Slice deletion keeps the same list object and cannot
                # over-pop, even for a zero or negative limit.
                del layers[:overflow]

    def update_layer(self, session_id: str, layer_id: str, updates: dict) -> bool:
        """
        Update an existing layer in a session.

        The first layer whose "id" equals ``layer_id`` is updated in place
        with ``updates``.

        Returns True if layer was found and updated, False if the session or
        the layer does not exist.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False

            for layer in session["layers"]:
                if layer.get("id") == layer_id:
                    layer.update(updates)
                    session["accessed"] = datetime.now()
                    return True

            return False

    def remove_layer(self, session_id: str, layer_id: str) -> bool:
        """
        Remove a layer from a session.

        Every layer whose "id" equals ``layer_id`` is dropped. The session's
        access time is refreshed whenever the session exists, even if no
        layer matched.

        Returns True if layer was found and removed.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False

            before = len(session["layers"])
            session["layers"] = [
                layer for layer in session["layers"] if layer.get("id") != layer_id
            ]
            session["accessed"] = datetime.now()

            return len(session["layers"]) < before

    def clear_session(self, session_id: str) -> None:
        """Clear all data for a session; unknown ids are ignored."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def get_layer_by_index(self, session_id: str, index: int) -> Optional[dict]:
        """
        Get a specific layer by 1-based index (for user references like 'Layer 1').

        Returns a shallow copy of the layer dict, or None when the session
        does not exist or the index is out of range. Does not refresh the
        session's access time.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return None

            layers = session["layers"]
            if 1 <= index <= len(layers):
                return layers[index - 1].copy()

            return None

    def cleanup_expired(self) -> int:
        """
        Remove sessions whose last access is older than the TTL.

        Falls back to the creation time for a session entry that has no
        "accessed" key.

        Returns number of expired sessions removed.
        """
        with self._lock:
            now = datetime.now()
            # Collect ids first: the dict must not change size while iterated.
            expired = [
                sid
                for sid, data in self._sessions.items()
                if now - data.get("accessed", data["created"]) > self.ttl
            ]

            for sid in expired:
                del self._sessions[sid]

            if expired:
                logger.info("Cleaned up %d expired sessions.", len(expired))

            return len(expired)

    def get_stats(self) -> dict:
        """
        Return statistics about active sessions.

        Keys: "active_sessions", "total_layers", "ttl_hours" (float) and
        "max_layers_per_session".
        """
        with self._lock:
            total_layers = sum(len(s["layers"]) for s in self._sessions.values())

            return {
                "active_sessions": len(self._sessions),
                "total_layers": total_layers,
                "ttl_hours": self.ttl.total_seconds() / 3600,
                "max_layers_per_session": self.max_layers,
            }
|
|
|
|
|
|
|
|
|
|
|
# Module-level cache for the shared store; populated lazily on first request.
_session_store: Optional[SessionStore] = None


def get_session_store() -> SessionStore:
    """Return the shared SessionStore, constructing it with defaults on first call."""
    global _session_store
    if _session_store is not None:
        return _session_store
    _session_store = SessionStore()
    return _session_store
|
|
|