OpenSpace / openspace /grounding /core /provider.py
darkfire514's picture
Upload 160 files
399b80c verified
"""
provider is to manage sessions of a backend, if the backend is mcp, then provider will manage sessions through servers
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Generic, TypeVar
from .tool import BaseTool
from .types import BackendType, SessionConfig, ToolResult, ToolStatus
from .session import BaseSession
from .security.policies import SecurityPolicyManager
from openspace.config import get_config
from openspace.utils.logging import Logger
logger = Logger.get_logger(__name__)
TSession = TypeVar('TSession', bound=BaseSession)
class Provider(ABC, Generic[TSession]):
"""Backend provider base class"""
def __init__(self, backend_type: BackendType, config: Dict[str, Any] = None):
self.backend_type = backend_type
self.config = config or {}
self.is_initialized = False
self._sessions: Dict[str, TSession] = {} # session management
self._session_counter: int = 0
self.security_manager = SecurityPolicyManager()
self._setup_security_policy(config)
def _setup_security_policy(self, config: dict | None = None):
security_policy = get_config().get_security_policy(self.backend_type.value)
self.security_manager.set_backend_policy(BackendType.SHELL, security_policy)
async def ensure_initialized(self) -> None:
"""
Internal helper. Guarantee that `initialize()` has been executed
"""
if not self.is_initialized:
await self.initialize()
@abstractmethod
async def initialize(self) -> None:
"""Initialize provider, call `create_session` to create all sessions if not exist
Subclasses should set `self.is_initialized = True` after successful initialization
"""
pass
@abstractmethod
async def create_session(self, session_config: SessionConfig) -> TSession:
"""Create session, update _sessions"""
pass
@abstractmethod
async def close_session(self, session_name: str) -> None:
"""Close session"""
pass
def list_sessions(self) -> List[str]:
"""Get all session IDs"""
return list(self._sessions.keys())
def get_session(self, session_name: str) -> Optional[TSession]:
"""Get session object by ID"""
return self._sessions.get(session_name)
async def close_all_sessions(self) -> None:
"""Provider shutdown cleanup"""
for session_name in list(self._sessions.keys()):
try:
await self.close_session(session_name)
except Exception as e:
print(f"Error closing session {session_name}: {e}")
self._sessions.clear()
self.is_initialized = False
def __repr__(self) -> str:
return (f"Provider(backend={self.backend_type.value}, "
f"initialized={self.is_initialized}, "
f"sessions={len(self._sessions)}, "
f"config_items={len(self.config)})")
async def list_tools(self, session_name: Optional[str] = None) -> List[BaseTool]:
"""
Return BaseTool list.
If session_name is specified, only return the tools of the specified session.
If session_name is not specified, return all tools of all sessions.
"""
await self.ensure_initialized()
if session_name:
session = self._sessions.get(session_name)
return await session.list_tools() if session else []
tools: list[BaseTool] = []
for sess in self._sessions.values():
tools.extend(await sess.list_tools())
return tools
async def call_tool(
self,
session_name: str,
tool_name: str,
parameters: Dict[str, Any] | None = None,
) -> ToolResult:
await self.ensure_initialized()
parameters = parameters or {}
session = self._sessions.get(session_name)
if session is None:
return ToolResult(
status=ToolStatus.ERROR,
content="",
error=f"Session '{session_name}' not found",
metadata={"session_name": session_name, "tool_name": tool_name},
)
try:
return await session.call_tool(tool_name, parameters)
except Exception as e:
logger.error("Execute tool error: %s @%s - %s", tool_name, session_name, e)
return ToolResult(
status=ToolStatus.ERROR,
content="",
error=str(e),
metadata={"session_name": session_name, "tool_name": tool_name},
)
class ProviderRegistry:
"""
Maintain mapping of BackendType -> Provider, and provide dynamic registration / retrieval capabilities
"""
def __init__(self) -> None:
self._providers: dict[BackendType, Provider] = {}
def register(self, provider: "Provider") -> None:
self._providers[provider.backend_type] = provider
logger.debug("Provider for %s registered", provider.backend_type)
def get(self, backend: BackendType) -> "Provider":
if backend not in self._providers:
raise KeyError(f"Provider for '{backend.value}' not registered")
return self._providers[backend]
def list(self) -> dict[BackendType, "Provider"]:
return dict(self._providers)