| """Session manager for handling multiple concurrent agent sessions.""" |
|
|
| import asyncio |
| import logging |
| import uuid |
| from dataclasses import dataclass, field |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any, Optional |
| import json |
|
|
| from websocket import manager as ws_manager |
|
|
| from agent.config import load_config |
| from agent.core.agent_loop import process_submission |
| from agent.core.session import Event, OpType, Session |
| from agent.core.tools import ToolRouter |
|
|
| |
| def _get_config_path(): |
| |
| backend_root = Path(__file__).parent |
| local_config = backend_root / "configs" / "main_agent_config.json" |
| if local_config.exists(): |
| return str(local_config) |
| |
| |
| docker_config = Path("/app/configs/main_agent_config.json") |
| if docker_config.exists(): |
| return str(docker_config) |
| |
| |
| project_root = Path(__file__).parent.parent |
| return str(project_root / "configs" / "main_agent_config.json") |
|
|
# Config path resolved once at import time; SessionManager falls back to it
# when no explicit config_path is passed to its constructor.
DEFAULT_CONFIG_PATH = _get_config_path()
|
|
|
|
| |
@dataclass
class Operation:
    """Operation to be executed by the agent.

    Instances are built by the SessionManager submit helpers and wrapped in a
    Submission before being placed on a session's submission queue.
    """

    # Kind of operation to perform (an OpType member, e.g. USER_INPUT or SHUTDOWN).
    op_type: OpType
    # Optional payload for the operation; None for operations that carry no data
    # (e.g. interrupt, undo, compact, shutdown as built in this module).
    data: Optional[dict[str, Any]] = None
|
|
|
|
@dataclass
class Submission:
    """Submission to the agent loop.

    Pairs an Operation with an identifier so a single queued request can be
    referred to individually.
    """

    # Identifier of this submission; SessionManager.submit generates values of
    # the form "sub_" + 8 hex characters.
    id: str
    # The operation the agent loop should process.
    operation: Operation
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class AgentSession:
    """Wrapper for an agent session with its associated resources.

    One instance is stored per session in ``SessionManager.sessions``, keyed by
    ``session_id``.
    """

    # Key under which this session is registered (a uuid4 string when created
    # through SessionManager.create_session).
    session_id: str
    # The underlying agent Session object.
    session: Session
    # Tool router built for this session; entered as an async context manager
    # for the lifetime of the session's run loop.
    tool_router: ToolRouter
    # Queue of Submission objects consumed by the session's run loop.
    submission_queue: asyncio.Queue
    # Owner of the session; "dev" is the default and is treated specially by
    # the access and capacity checks in SessionManager.
    user_id: str = "dev"
    # Token supplied at session creation (also copied onto the Session object);
    # presumably a Hugging Face access token - TODO confirm.
    hf_token: str | None = None
    # Background task running the session loop; assigned right after creation.
    task: asyncio.Task | None = None
    # Creation timestamp as a naive datetime in UTC.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; switching
    # to an aware datetime would change the isoformat() output exposed by
    # SessionManager.get_session_info, so confirm with API consumers first.
    created_at: datetime = field(default_factory=datetime.utcnow)
    # Set to False by the run loop's cleanup when the session ends.
    is_active: bool = True
|
|
|
|
class SessionCapacityError(Exception):
    """Signals that a new session cannot be created because a limit was hit.

    Attributes:
        error_type: Which limit was reached. This module uses ``"global"``
            for the server-wide cap and ``"per_user"`` for the per-user cap.
    """

    def __init__(self, message: str, error_type: str = "global") -> None:
        # Record the limit category first, then let Exception store the message.
        self.error_type = error_type
        super().__init__(message)
|
|
|
|
| |
| |
| |
# Server-wide cap on concurrently active sessions, checked in
# SessionManager.create_session.
MAX_SESSIONS: int = 50
# Cap on concurrently active sessions for a single user; the "dev" user is
# exempt from this per-user check.
MAX_SESSIONS_PER_USER: int = 10
|
|
|
|
class SessionManager:
    """Manages multiple concurrent agent sessions.

    Each session is stored as an AgentSession in ``self.sessions`` (keyed by
    session ID) and is driven by its own background task running
    ``_run_session``. ``self._lock`` guards lookups and mutations of
    ``self.sessions`` that happen inside coroutines; it must never be held
    while awaiting a session task, because the task's own cleanup acquires it.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load the agent config and start with an empty session table.

        Args:
            config_path: Path to the config file; DEFAULT_CONFIG_PATH is used
                when omitted or empty.
        """
        self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
        self.sessions: dict[str, AgentSession] = {}
        self._lock = asyncio.Lock()

    def _count_user_sessions(self, user_id: str) -> int:
        """Count active sessions owned by a specific user."""
        return sum(
            1
            for s in self.sessions.values()
            if s.user_id == user_id and s.is_active
        )

    def _check_capacity(self, user_id: str) -> None:
        """Raise SessionCapacityError if a new session would exceed a limit.

        Intended to be called with ``self._lock`` held so the counts are not
        changed by another coroutine mid-check.
        """
        active_count = self.active_session_count
        if active_count >= MAX_SESSIONS:
            raise SessionCapacityError(
                f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). "
                "Please try again later.",
                error_type="global",
            )
        # The "dev" user is exempt from the per-user cap.
        if user_id != "dev":
            user_count = self._count_user_sessions(user_id)
            if user_count >= MAX_SESSIONS_PER_USER:
                raise SessionCapacityError(
                    f"You have reached the maximum of {MAX_SESSIONS_PER_USER} "
                    "concurrent sessions. Please close an existing session first.",
                    error_type="per_user",
                )

    @staticmethod
    def _create_session_folder(session_id: str, user_id: str) -> Path:
        """Create the on-disk folder layout and metadata.json for a session.

        The folder lives at ``<module dir>/sessions/<session_id>`` and gets
        ``files`` and ``documents`` subfolders. Returns the session folder.
        """
        sessions_root = Path(__file__).parent / "sessions"
        sessions_root.mkdir(parents=True, exist_ok=True)
        session_folder = sessions_root / session_id
        session_folder.mkdir(parents=True, exist_ok=True)
        (session_folder / "files").mkdir(exist_ok=True)
        (session_folder / "documents").mkdir(exist_ok=True)

        metadata = {
            "session_id": session_id,
            "user_id": user_id,
            "created_at": datetime.utcnow().isoformat(),
            "root_path": str(session_folder),
        }
        metadata_file = session_folder / "metadata.json"
        metadata_file.write_text(json.dumps(metadata, indent=2))

        logger.info(f"Created session folder: {session_folder}")
        return session_folder

    async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str:
        """Create a new agent session and return its ID.

        Session() and ToolRouter() constructors contain blocking I/O
        (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are
        executed in a thread pool to avoid freezing the async event loop.

        Args:
            user_id: The ID of the user who owns this session.
            hf_token: Optional token stored on both the Session object and
                the AgentSession wrapper (presumably a Hugging Face access
                token - TODO confirm).

        Returns:
            The ID (a uuid4 string) of the newly created session.

        Raises:
            SessionCapacityError: If the server or user has reached the
                maximum number of concurrent sessions.
        """
        # NOTE(review): the capacity check and the registration further down
        # are separate critical sections (the lock is released while the
        # session is being constructed), so concurrent calls can overshoot
        # the limits slightly.
        async with self._lock:
            self._check_capacity(user_id)

        session_id = str(uuid.uuid4())
        session_folder = self._create_session_folder(session_id, user_id)

        submission_queue: asyncio.Queue = asyncio.Queue()
        event_queue: asyncio.Queue = asyncio.Queue()

        import time as _time

        def _create_session_sync():
            # Runs in a worker thread: both constructors may block on I/O.
            t0 = _time.monotonic()
            tool_router = ToolRouter(self.config.mcpServers)
            session = Session(event_queue, config=self.config, tool_router=tool_router)
            t1 = _time.monotonic()
            logger.info(f"Session initialized in {t1 - t0:.2f}s")
            return tool_router, session

        tool_router, session = await asyncio.to_thread(_create_session_sync)

        # Attach per-session context onto the Session object itself.
        session.hf_token = hf_token
        session.session_folder = str(session_folder)

        agent_session = AgentSession(
            session_id=session_id,
            session=session,
            tool_router=tool_router,
            submission_queue=submission_queue,
            user_id=user_id,
            hf_token=hf_token,
        )

        async with self._lock:
            self.sessions[session_id] = agent_session

        # The run loop looks the session up by ID, so it is started only
        # after the session has been registered.
        task = asyncio.create_task(
            self._run_session(session_id, submission_queue, event_queue, tool_router)
        )
        agent_session.task = task

        logger.info(f"Created session {session_id} for user {user_id}")
        return session_id

    async def _run_session(
        self,
        session_id: str,
        submission_queue: asyncio.Queue,
        event_queue: asyncio.Queue,
        tool_router: ToolRouter,
    ) -> None:
        """Run the agent loop for a session and forward events to WebSocket.

        Emits a "ready" event, then processes submissions until the session
        stops running, a submission asks to stop, or the task is cancelled.
        On exit the event forwarder is cancelled and the session is marked
        inactive.
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            logger.error(f"Session {session_id} not found")
            return

        session = agent_session.session

        # Forward agent events to the WebSocket concurrently with the loop.
        event_forwarder = asyncio.create_task(
            self._forward_events(session_id, event_queue)
        )

        try:
            async with tool_router:
                await session.send_event(
                    Event(event_type="ready", data={"message": "Agent initialized"})
                )

                while session.is_running:
                    try:
                        # Short timeout so session.is_running is re-checked
                        # about once per second while the queue is idle.
                        submission = await asyncio.wait_for(
                            submission_queue.get(), timeout=1.0
                        )
                        should_continue = await process_submission(session, submission)
                        if not should_continue:
                            break
                    except asyncio.TimeoutError:
                        continue
                    except asyncio.CancelledError:
                        logger.info(f"Session {session_id} cancelled")
                        break
                    except Exception as e:
                        # logger.exception keeps the traceback, which a plain
                        # error-level message would lose.
                        logger.exception(f"Error in session {session_id}: {e}")
                        await session.send_event(
                            Event(event_type="error", data={"error": str(e)})
                        )

        finally:
            # NOTE(review): events still queued at this point are dropped
            # when the forwarder is cancelled - confirm this is acceptable.
            event_forwarder.cancel()
            try:
                await event_forwarder
            except asyncio.CancelledError:
                pass

            # Callers must not hold self._lock while awaiting this task,
            # otherwise this cleanup cannot complete.
            async with self._lock:
                if session_id in self.sessions:
                    self.sessions[session_id].is_active = False

            logger.info(f"Session {session_id} ended")

    async def _forward_events(
        self, session_id: str, event_queue: asyncio.Queue
    ) -> None:
        """Forward events from the agent to the WebSocket.

        Runs until cancelled; a failure to forward one event is logged and
        does not stop the loop.
        """
        while True:
            try:
                event: Event = await event_queue.get()
                await ws_manager.send_event(session_id, event.event_type, event.data)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error forwarding event for {session_id}: {e}")

    async def submit(self, session_id: str, operation: Operation) -> bool:
        """Submit an operation to a session.

        Returns:
            True if the operation was queued; False if the session does not
            exist or is no longer active.
        """
        async with self._lock:
            agent_session = self.sessions.get(session_id)

        if not agent_session or not agent_session.is_active:
            logger.warning(f"Session {session_id} not found or inactive")
            return False

        submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation)
        await agent_session.submission_queue.put(submission)
        return True

    async def submit_user_input(self, session_id: str, text: str) -> bool:
        """Submit user input to a session."""
        operation = Operation(op_type=OpType.USER_INPUT, data={"text": text})
        return await self.submit(session_id, operation)

    async def submit_approval(
        self, session_id: str, approvals: list[dict[str, Any]]
    ) -> bool:
        """Submit tool approvals to a session."""
        operation = Operation(
            op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}
        )
        return await self.submit(session_id, operation)

    async def interrupt(self, session_id: str) -> bool:
        """Interrupt a session."""
        operation = Operation(op_type=OpType.INTERRUPT)
        return await self.submit(session_id, operation)

    async def undo(self, session_id: str) -> bool:
        """Undo last turn in a session."""
        operation = Operation(op_type=OpType.UNDO)
        return await self.submit(session_id, operation)

    async def compact(self, session_id: str) -> bool:
        """Compact context in a session."""
        operation = Operation(op_type=OpType.COMPACT)
        return await self.submit(session_id, operation)

    async def shutdown_session(self, session_id: str) -> bool:
        """Shutdown a specific session.

        Queues a SHUTDOWN operation and then waits up to 5 seconds for the
        session task to finish, cancelling it if it does not.

        Returns:
            True if the shutdown operation was queued; False if the session
            does not exist or is already inactive.
        """
        operation = Operation(op_type=OpType.SHUTDOWN)
        success = await self.submit(session_id, operation)
        if not success:
            return success

        # Only the lookup happens under the lock. _run_session acquires the
        # same lock in its cleanup to mark the session inactive, so awaiting
        # the task while holding the lock would stall the task until the
        # timeout fires and then cancel it in the middle of that cleanup.
        async with self._lock:
            agent_session = self.sessions.get(session_id)

        task = agent_session.task if agent_session else None
        if task:
            try:
                await asyncio.wait_for(task, timeout=5.0)
            except asyncio.TimeoutError:
                task.cancel()

        return success

    async def delete_session(self, session_id: str) -> bool:
        """Delete a session entirely.

        Removes the session from the table and cancels its task if it is
        still running.

        Returns:
            True if the session existed and was removed, False otherwise.
        """
        async with self._lock:
            agent_session = self.sessions.pop(session_id, None)

        if not agent_session:
            return False

        # Stop the run loop if it has not finished on its own.
        if agent_session.task and not agent_session.task.done():
            agent_session.task.cancel()
            try:
                await agent_session.task
            except asyncio.CancelledError:
                pass

        return True

    def get_session_owner(self, session_id: str) -> str | None:
        """Get the user_id that owns a session, or None if session doesn't exist."""
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None
        return agent_session.user_id

    def verify_session_access(self, session_id: str, user_id: str) -> bool:
        """Check if a user has access to a session.

        Returns False when the session does not exist (this also applies to
        the "dev" user). Otherwise returns True if any of these holds:
            - user_id is "dev" (dev mode bypass)
            - the session is owned by "dev"
            - user_id is the session's owner
        """
        owner = self.get_session_owner(session_id)
        if owner is None:
            return False
        if user_id == "dev" or owner == "dev":
            return True
        return owner == user_id

    def get_session_info(self, session_id: str) -> dict[str, Any] | None:
        """Get information about a session.

        Returns:
            A dict with session_id, created_at (ISO string), is_active,
            message_count and user_id, or None if the session doesn't exist.
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None

        return {
            "session_id": session_id,
            "created_at": agent_session.created_at.isoformat(),
            "is_active": agent_session.is_active,
            "message_count": len(agent_session.session.context_manager.items),
            "user_id": agent_session.user_id,
        }

    def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
        """List sessions, optionally filtered by user.

        Both active and inactive sessions are included.

        Args:
            user_id: If provided, only return sessions owned by this user.
                If "dev" (or omitted), return all sessions.
        """
        results = []
        for sid in self.sessions:
            info = self.get_session_info(sid)
            if not info:
                continue
            if user_id and user_id != "dev" and info.get("user_id") != user_id:
                continue
            results.append(info)
        return results

    @property
    def active_session_count(self) -> int:
        """Get count of active sessions."""
        return sum(1 for s in self.sessions.values() if s.is_active)
|
|
|
|
| |
# Module-level SessionManager instance; constructing it loads the agent
# config (from DEFAULT_CONFIG_PATH) as a side effect of importing this module.
session_manager = SessionManager()
|
|