|
|
""" |
|
|
Base Tool Manager |
|
|
|
|
|
This module provides a general framework for managing different types of tools
in the agentic pipeline. Support for a specific tool type is added by subclassing
BaseToolManager; ToolRegistry gives central access to the registered managers.
|
|
""" |
|
|
|
|
|
import json
import os
import shutil
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type

from langchain_core.tools import BaseTool, tool
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
# Append the grandparent of this file's directory to sys.path at import time.
# os.path.dirname is applied three times to the absolute file path:
# file -> its directory -> parent -> grandparent.
# NOTE(review): presumably the grandparent is the project root, so that
# project-local modules can be imported elsewhere — confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
|
|
|
|
|
|
@dataclass
class ToolConfig:
    """Configuration for a tool.

    ``name``, ``tool_type`` and ``description`` are reported by
    ``BaseToolManager.get_tool_info``; ``fallback_enabled`` gates the
    fallback path in ``BaseToolManager._handle_error``.
    """

    # Tool name; used in printed messages and reported by get_tool_info().
    name: str
    # Free-form tool category, reported under the "type" key by get_tool_info().
    tool_type: str
    # Human-readable description, reported by get_tool_info().
    description: str
    # When False, _handle_error() never switches to the fallback tool.
    fallback_enabled: bool = True
|
|
|
|
|
class ToolStatus(Enum):
    """Status of a tool.

    ``BaseToolManager.is_available()`` treats AVAILABLE and FALLBACK as usable.
    """

    # Initial state of a manager, and the state after cleanup().
    NOT_AVAILABLE = "not_available"
    # Primary tool in use. Never set by this module itself — presumably set
    # by subclasses once their tool is ready (verify).
    AVAILABLE = "available"
    # Set by _initialize_fallback() when the fallback tool is in use.
    FALLBACK = "fallback"
    # Set by _handle_error() when something failed.
    ERROR = "error"
|
|
|
|
|
@dataclass
class ToolConfigExtended(ToolConfig):
    """Extended configuration for a tool.

    Adds two optional fields on top of ``ToolConfig``. Both have defaults,
    which dataclass inheritance requires because the base class already ends
    with a defaulted field (``fallback_enabled``).
    """

    # Pydantic model describing the tool's input arguments. The default is
    # None, so the annotation must be Optional.
    input_schema: Optional[Type[BaseModel]] = None
    # Caching flag. Not read anywhere in this module — presumably consumed by
    # subclasses (verify).
    cache_results: bool = True
|
|
|
|
|
|
|
|
class BaseToolManager(ABC):
    """
    Abstract base class for tool managers.

    Provides common functionality for tool initialization and management:
    status tracking via ``ToolStatus``, a private temporary directory per
    manager, error handling with an optional fallback tool, and cleanup.

    Subclasses implement ``_initialize_tool`` (called from ``__init__``),
    ``_create_tool`` and ``_create_fallback_tool``.
    """

    def __init__(self, config: ToolConfigExtended):
        """
        Initialize the tool manager.

        Creates a new temporary directory, then calls ``_initialize_tool()``.

        Args:
            config: Tool configuration

        Raises:
            Whatever ``_initialize_tool()`` raises. The temporary directory is
            removed before the exception propagates, because the half-built
            manager is never returned to a caller that could clean it up.
        """
        self.config = config
        self.status = ToolStatus.NOT_AVAILABLE
        self.tool = None
        # mkdtemp() creates the directory itself, so no extra mkdir() is needed.
        # The caller of mkdtemp() is responsible for deleting it (see cleanup()).
        self.temp_dir = Path(tempfile.mkdtemp())

        try:
            self._initialize_tool()
        except BaseException:
            shutil.rmtree(self.temp_dir, ignore_errors=True)
            raise

    @abstractmethod
    def _initialize_tool(self):
        """Initialize the specific tool. Must be implemented by subclasses.

        Called once from ``__init__`` after ``config``, ``status``, ``tool``
        and ``temp_dir`` have been set.
        """

    @abstractmethod
    def _create_tool(self) -> BaseTool:
        """Create the tool instance. Must be implemented by subclasses.

        Not called by this base class; presumably invoked by the subclass's
        ``_initialize_tool`` (verify).
        """

    @abstractmethod
    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool. Must be implemented by subclasses.

        Called by ``_initialize_fallback``.
        """

    def is_available(self) -> bool:
        """Return True when either the primary or the fallback tool is in use."""
        return self.status in (ToolStatus.AVAILABLE, ToolStatus.FALLBACK)

    def get_status(self) -> ToolStatus:
        """Get current tool status."""
        return self.status

    def get_tool(self) -> Optional[BaseTool]:
        """Get the tool instance (None until one has been set)."""
        return self.tool

    def get_tool_info(self) -> Dict[str, Any]:
        """Return a JSON-friendly summary of the tool's config and state."""
        return {
            "name": self.config.name,
            "type": self.config.tool_type,
            "status": self.status.value,
            "description": self.config.description,
            "temp_dir": str(self.temp_dir),
        }

    def _set_status(self, status: ToolStatus):
        """Set the tool status."""
        self.status = status

    def _handle_error(self, error: Exception, fallback: bool = True):
        """Record a failure and, if allowed, switch to the fallback tool.

        Args:
            error: The exception that occurred; it is printed, not re-raised.
            fallback: Per-call switch. The fallback is attempted only when this
                and ``config.fallback_enabled`` are both true.

        Raises:
            Whatever ``_create_fallback_tool()`` raises; the status is left at
            ERROR in that case.
        """
        print(f"Error in {self.config.name}: {error}")
        self._set_status(ToolStatus.ERROR)

        if fallback and self.config.fallback_enabled:
            self._initialize_fallback()

    def _initialize_fallback(self):
        """Initialize fallback tool when main tool fails.

        The status switches to FALLBACK only after the fallback tool has been
        created. If creation raises, the status is set to ERROR and the
        exception propagates, so ``is_available()`` never reports a tool that
        does not exist.
        """
        print(f"Initializing fallback for {self.config.name}")
        try:
            fallback_tool = self._create_fallback_tool()
        except Exception:
            self._set_status(ToolStatus.ERROR)
            raise
        self.tool = fallback_tool
        self._set_status(ToolStatus.FALLBACK)

    def cleanup(self):
        """Release the tool and delete the manager's temporary directory.

        Resets the status to NOT_AVAILABLE. Safe to call more than once,
        because a directory that is already gone is ignored.
        """
        self.tool = None
        self._set_status(ToolStatus.NOT_AVAILABLE)
        shutil.rmtree(self.temp_dir, ignore_errors=True)
|
|
|
|
|
|
|
|
class ToolRegistry:
    """
    Registry for managing multiple tool managers.

    Provides a centralized way to access different tools. Tools are read from
    their managers at lookup time rather than cached at registration. A
    manager that later switches to its fallback tool, or is cleaned up, is
    therefore reflected immediately by ``get_tool`` and ``get_all_tools``.
    """

    def __init__(self):
        self._managers: Dict[str, BaseToolManager] = {}
        self._configs: Dict[str, ToolConfigExtended] = {}

    def register_tool(self, name: str, manager: BaseToolManager, config: ToolConfigExtended):
        """Register ``manager`` (and its ``config``) under ``name``.

        An existing registration with the same name is replaced.
        """
        self._managers[name] = manager
        self._configs[name] = config
        print(f"Registered tool: {name}")

    def get_tool(self, name: str) -> Optional[BaseTool]:
        """Return the manager's current tool, or None if unregistered or unset."""
        manager = self._managers.get(name)
        if manager is None:
            return None
        return manager.get_tool()

    def get_manager(self, name: str) -> Optional[BaseToolManager]:
        """Get a tool manager by name."""
        return self._managers.get(name)

    def get_available_tools(self) -> List[str]:
        """Return names of tools whose manager reports ``is_available()``."""
        return [name for name, manager in self._managers.items() if manager.is_available()]

    def get_tool_info(self, name: str) -> Optional[Dict[str, Any]]:
        """Return ``get_tool_info()`` of the named manager, or None if unregistered."""
        manager = self._managers.get(name)
        if manager is None:
            return None
        return manager.get_tool_info()

    def get_all_tools(self) -> Dict[str, BaseTool]:
        """Return every registered manager's current tool, skipping None tools."""
        tools: Dict[str, BaseTool] = {}
        for name, manager in self._managers.items():
            current = manager.get_tool()
            if current is not None:
                tools[name] = current
        return tools

    def cleanup_all(self):
        """Clean up every registered manager and empty the registry.

        ``cleanup()`` is attempted on every manager even if an earlier one
        raises. The registry is emptied regardless, and the first error is
        re-raised afterwards so failures are not silently lost.
        """
        first_error = None
        for name, manager in self._managers.items():
            try:
                manager.cleanup()
            except Exception as exc:
                print(f"Error cleaning up tool {name}: {exc}")
                if first_error is None:
                    first_error = exc
        self._managers.clear()
        self._configs.clear()
        if first_error is not None:
            raise first_error
|
|
|
|
|
|
|
|
|
|
|
# Module-level registry instance. Wrappers built by create_tool_wrapper()
# look their tool up here by name on every call.
tool_registry = ToolRegistry()
|
|
|
|
|
|
|
|
def create_tool_wrapper(tool_name: str, description: str, input_schema: Type[BaseModel]) -> Callable:
    """
    Create a tool wrapper function for LangChain integration.

    The wrapper looks the tool up in the module-level ``tool_registry`` on
    every call (not once at creation time). It validates its keyword
    arguments with ``input_schema`` and passes the validated data, as a
    dict, to ``tool.run``.

    Args:
        tool_name: Name of the tool; also used as the wrapper's __name__/__qualname__
        description: Description of the tool; used as the wrapper's __doc__
        input_schema: Pydantic model used to validate the wrapper's kwargs

    Returns:
        Tool wrapper function. The wrapper returns whatever ``tool.run``
        returns. It returns an ``{"error": ...}`` dict when the tool is not
        registered/set, or when validation or execution raises.
    """

    def tool_wrapper(**kwargs) -> Any:
        """Tool wrapper function."""
        # Named target_tool so it does not shadow the imported langchain `tool`.
        target_tool = tool_registry.get_tool(tool_name)
        if target_tool is None:
            return {"error": f"Tool {tool_name} not available"}

        try:
            input_data = input_schema(**kwargs)
            # Pydantic v2 deprecates .dict() in favour of .model_dump();
            # fall back to .dict() for pydantic v1 models.
            dump = getattr(input_data, "model_dump", None) or input_data.dict
            return target_tool.run(dump())
        except Exception as e:
            # Failures are returned as data rather than raised.
            return {"error": f"Tool {tool_name} failed: {str(e)}"}

    tool_wrapper.__name__ = tool_name
    tool_wrapper.__qualname__ = tool_name
    tool_wrapper.__doc__ = description

    return tool_wrapper
|
|
|