GeoQuery / backend /core /session_store.py
GerardCB's picture
Deploy to Spaces (Final Clean)
4851501
"""
Session Store Service
Thread-safe session-scoped storage for user layers and context.
Replaces global SESSION_LAYERS with per-session isolation.
"""
import logging
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
# Module-scoped logger; records are emitted under this module's dotted import name.
logger = logging.getLogger(__name__)
class SessionStore:
    """
    Thread-safe, session-scoped storage for user map layers, with TTL expiry.

    Each session is keyed by an opaque ``session_id`` string and holds:
    - layers: ordered list of layer dicts added by the user (oldest first)
    - created / accessed: timestamps used for TTL-based expiry

    The class is a process-wide singleton: every ``SessionStore(...)`` call
    returns the same instance, and constructor arguments only take effect on
    the first construction. Sessions idle for longer than the TTL (default
    2 hours) are removed by ``cleanup_expired``.

    Every public method runs under an internal lock. Layer dicts are stored
    by reference: ``add_layer`` keeps the caller's dict, and ``get_layers``
    returns a new list containing the stored dict objects.
    """

    _instance: Optional["SessionStore"] = None
    # Guards singleton creation and first-time initialisation so concurrent
    # first callers cannot end up with two different ``_sessions``/``_lock``.
    _instance_lock = threading.Lock()

    def __new__(cls, *args: Any, **kwargs: Any) -> "SessionStore":
        # Constructor arguments are consumed by __init__; they must still be
        # accepted here, otherwise SessionStore(ttl_hours=...) raises TypeError.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance.initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, ttl_hours: float = 2, max_layers_per_session: int = 15):
        """
        Configure the store on first construction; later calls are no-ops.

        Args:
            ttl_hours: Idle time in hours after which ``cleanup_expired``
                removes a session. Must be positive.
            max_layers_per_session: Maximum number of layers kept per
                session; the oldest are evicted beyond this. Must be >= 1.

        Raises:
            ValueError: If an argument is out of range (checked only on the
                construction that actually initialises the singleton).
        """
        with self._instance_lock:
            if self.initialized:
                return
            if ttl_hours <= 0:
                raise ValueError(f"ttl_hours must be positive, got {ttl_hours!r}")
            if max_layers_per_session < 1:
                raise ValueError(
                    f"max_layers_per_session must be >= 1, got {max_layers_per_session!r}"
                )
            self._sessions: Dict[str, dict] = {}
            self._lock = threading.Lock()
            self.ttl = timedelta(hours=ttl_hours)
            self.max_layers = max_layers_per_session
            self.initialized = True
        logger.info(
            "SessionStore initialized with TTL=%sh, max_layers=%s",
            ttl_hours,
            max_layers_per_session,
        )

    def _get_or_create_session(self, session_id: str) -> dict:
        """Return the session dict for ``session_id``, creating an empty one if absent.

        The caller must already hold ``self._lock``.
        """
        session = self._sessions.get(session_id)
        if session is None:
            now = datetime.now()
            session = {"layers": [], "created": now, "accessed": now}
            self._sessions[session_id] = session
        return session

    def get_layers(self, session_id: str) -> List[dict]:
        """Return a copy of the session's layer list (oldest first).

        Creates an empty session if none exists and refreshes its access time.
        The list is new, but the layer dicts inside it are the stored objects.
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            session["accessed"] = datetime.now()
            return list(session["layers"])

    def add_layer(self, session_id: str, layer: dict) -> None:
        """
        Append ``layer`` to a session, creating the session if needed.

        If the session then exceeds ``max_layers``, the oldest layers are
        dropped (FIFO eviction).
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            layers = session["layers"]
            layers.append(layer)
            session["accessed"] = datetime.now()
            excess = len(layers) - self.max_layers
            if excess > 0:
                for removed in layers[:excess]:
                    logger.debug(
                        "Session %s: removed oldest layer %s",
                        session_id[:8],
                        removed.get("name", "unknown"),
                    )
                del layers[:excess]

    def update_layer(self, session_id: str, layer_id: str, updates: dict) -> bool:
        """
        Merge ``updates`` into the first layer whose ``"id"`` equals ``layer_id``.

        Returns True if such a layer was found and updated, False if the
        session or the layer does not exist. Does not create sessions.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False
            for layer in session["layers"]:
                if layer.get("id") == layer_id:
                    layer.update(updates)
                    session["accessed"] = datetime.now()
                    return True
            return False

    def remove_layer(self, session_id: str, layer_id: str) -> bool:
        """
        Remove every layer whose ``"id"`` equals ``layer_id`` from a session.

        Returns True if at least one layer was removed. Does not create sessions.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False
            before = len(session["layers"])
            session["layers"] = [lyr for lyr in session["layers"] if lyr.get("id") != layer_id]
            session["accessed"] = datetime.now()
            return len(session["layers"]) < before

    def clear_session(self, session_id: str) -> None:
        """Delete all data for a session; unknown ids are ignored."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def get_layer_by_index(self, session_id: str, index: int) -> Optional[dict]:
        """
        Return a shallow copy of the layer at 1-based ``index`` (for user
        references like 'Layer 1'), or None if the session or index is missing.
        Does not create sessions.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return None
            # Reading a layer counts as activity, like the other accessors.
            session["accessed"] = datetime.now()
            layers = session["layers"]
            if 1 <= index <= len(layers):
                return layers[index - 1].copy()
            return None

    def cleanup_expired(self) -> int:
        """
        Remove sessions whose last access is older than the TTL.

        Returns the number of sessions removed.
        """
        with self._lock:
            cutoff = datetime.now() - self.ttl
            expired = [sid for sid, data in self._sessions.items() if data["accessed"] < cutoff]
            for sid in expired:
                del self._sessions[sid]
        if expired:
            logger.info("Cleaned up %d expired sessions.", len(expired))
        return len(expired)

    def get_stats(self) -> dict:
        """Return counts of active sessions and stored layers plus the configured limits."""
        with self._lock:
            total_layers = sum(len(s["layers"]) for s in self._sessions.values())
            return {
                "active_sessions": len(self._sessions),
                "total_layers": total_layers,
                "ttl_hours": self.ttl.total_seconds() / 3600,
                "max_layers_per_session": self.max_layers,
            }
# Module-level handle on the shared SessionStore, populated lazily on first use.
_session_store: Optional[SessionStore] = None


def get_session_store() -> SessionStore:
    """Return the process-wide SessionStore, constructing it on the first call."""
    global _session_store
    store = _session_store
    if store is None:
        store = SessionStore()
        _session_store = store
    return store