Spaces:
Runtime error
Runtime error
File size: 4,823 Bytes
4962437 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Optional, Tuple, Type

from langchain.agents import load_tools
from langchain.agents.agent import AgentExecutor
from langchain.llms.base import BaseLLM
from pydantic import BaseModel
class ToolScope(Enum):
    """Scope that decides how a toolset method is turned into a tool.

    ``GLOBAL`` tools are converted without a session binding, while
    ``SESSION`` tools receive a ``get_session`` keyword argument on every
    call once converted.
    """

    # Tool is not bound to any session.
    GLOBAL = "global"
    # Tool is bound to a session getter when converted.
    SESSION = "session"
class ToolException(Exception):
    """Error type for failures raised by a tool's underlying callable."""
class BaseTool(ABC):
    """Abstract interface shared by every tool in this module.

    Subclasses provide a synchronous ``run`` and an asynchronous ``arun``.
    Calling a tool instance directly is shorthand for ``run``.
    """

    # Tool identifier; concrete subclasses assign it.
    name: str
    # Human-readable description; concrete subclasses assign it.
    description: str

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool synchronously."""

    @abstractmethod
    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool asynchronously."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Direct invocation is routed through the synchronous path.
        return self.run(*args, **kwargs)
class Tool(BaseTool):
    """Tool backed by an arbitrary callable.

    ``run`` calls ``func`` directly. ``arun`` calls ``func`` and awaits the
    result only if it is awaitable, so both ``async def`` callables and
    plain synchronous callables work on the async path.
    """

    def __init__(self, name: str, description: str, func: Callable[..., Any]):
        self.name = name
        self.description = description
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` synchronously; exceptions (incl. ToolException) propagate."""
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` and await its result when it is awaitable.

        Previously the result was always awaited, which raised TypeError for
        synchronous callables.
        """
        result = self.func(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result
class StructuredTool(BaseTool):
    """Tool backed by a callable plus a pydantic model describing its arguments.

    ``arun`` awaits the callable's result only when it is awaitable, so
    synchronous callables also work on the async path.

    NOTE(review): ``args_schema`` is stored but not used to validate
    arguments in this class; confirm whether callers rely on it elsewhere.
    """

    def __init__(
        self,
        name: str,
        description: str,
        args_schema: Type[BaseModel],
        func: Callable[..., Any],
    ):
        self.name = name
        self.description = description
        self.args_schema = args_schema
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` synchronously; exceptions (incl. ToolException) propagate."""
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` and await its result when it is awaitable."""
        result = self.func(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result
# Zero-argument callable returning a (str, AgentExecutor) pair for the current
# session; per-session tools receive it as a ``get_session`` keyword argument.
# The meaning of the str element is not shown here (presumably a session id --
# TODO confirm).
# NOTE(review): the ``lambda: []`` defaults used for SessionGetter parameters in
# this file return an empty list, which does not match this declared type.
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
class ToolWrapper:
    """Name, description, scope and callable of one tool, convertible to a Tool."""

    def __init__(
        self,
        name: str,
        description: str,
        scope: ToolScope,
        func: Callable[..., Any],
    ):
        self.name = name
        self.description = description
        self.scope = scope
        self.func = func

    def is_global(self) -> bool:
        """Return True when this tool has ``ToolScope.GLOBAL`` scope."""
        return self.scope == ToolScope.GLOBAL

    def is_per_session(self) -> bool:
        """Return True when this tool has ``ToolScope.SESSION`` scope."""
        return self.scope == ToolScope.SESSION

    def to_tool(self, get_session: SessionGetter = lambda: []) -> BaseTool:
        """Build a ``Tool`` from this wrapper.

        Per-session tools call the wrapped callable with an extra
        ``get_session=get_session`` keyword argument. The wrapper itself is
        not modified, so repeated calls do not stack bindings.

        Args:
            get_session: Session getter injected into per-session tools;
                ignored for other scopes.

        Returns:
            A new ``Tool`` with this wrapper's name and description.
        """
        # Capture the callable in a local. The old code rebound self.func to a
        # lambda that looked up self.func at call time, i.e. itself, which
        # recursed forever.
        target = self.func
        if not self.is_per_session():
            return Tool(name=self.name, description=self.description, func=target)

        def bound(*args: Any, **kwargs: Any) -> Any:
            return target(*args, **kwargs, get_session=get_session)

        return Tool(name=self.name, description=self.description, func=bound)
class BaseToolSet:
    """Base class for groups of tools exposed as attributes of a toolset."""

    def tool_wrappers(cls) -> list[ToolWrapper]:
        """Wrap every attribute that carries an ``is_tool`` marker.

        Each matching attribute must also expose ``name``, ``description``
        and ``scope``, which are copied into the ``ToolWrapper``. The marker
        is presumably set by a decorator defined elsewhere (not shown here).

        Note: despite its name, ``cls`` receives the toolset *instance* (this
        is a regular method, called as ``toolset.tool_wrappers()``), so
        matched methods are bound to that instance.
        """
        wrappers: list[ToolWrapper] = []
        for attr_name in dir(cls):
            # Look each attribute up exactly once; the previous version called
            # getattr twice, which evaluated properties twice.
            member = getattr(cls, attr_name)
            if not hasattr(member, "is_tool"):
                continue
            wrappers.append(
                ToolWrapper(member.name, member.description, member.scope, member)
            )
        return wrappers
class ToolCreator(ABC):
    """Strategy interface that turns toolsets into concrete tools."""

    @abstractmethod
    def create_tools(
        self,
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Create tools from ``toolsets``.

        ``ToolsFactory.create_tools`` always passes ``get_session`` as the
        second positional argument, so implementations must accept it even
        if they ignore it.
        """
class GlobalToolsCreator(ToolCreator):
    """Create only the ``ToolScope.GLOBAL`` tools of each toolset."""

    def create_tools(
        self,
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Return the global tools of ``toolsets``, in toolset order.

        ``get_session`` is accepted only so this creator can be driven by
        ``ToolsFactory.create_tools``, which always passes it. Previously that
        call raised TypeError. Global tools are not bound to a session, so the
        argument is ignored.
        """
        tools: list[BaseTool] = []
        for toolset in toolsets:
            tools.extend(ToolsFactory.from_toolset(toolset=toolset, only_global=True))
        return tools
class SessionToolsCreator(ToolCreator):
    """Create only the ``ToolScope.SESSION`` tools of each toolset."""

    def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Return per-session tools of ``toolsets`` bound to ``get_session``.

        Tools are returned grouped by toolset, in the order ``toolsets`` and
        each toolset's wrappers are iterated.
        """
        return [
            session_tool
            for toolset in toolsets
            for session_tool in ToolsFactory.from_toolset(
                toolset=toolset,
                only_per_session=True,
                get_session=get_session,
            )
        ]
class ToolsFactory:
    """Static helpers that turn toolsets or tool names into tool instances."""

    @staticmethod
    def from_toolset(
        toolset: BaseToolSet,
        only_global: Optional[bool] = False,
        only_per_session: Optional[bool] = False,
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Convert the wrappers of ``toolset`` into tools.

        Args:
            toolset: Source of ``ToolWrapper`` objects via ``tool_wrappers()``.
            only_global: When truthy, skip wrappers that are not global.
            only_per_session: When truthy, skip wrappers that are not per-session.
                If both filters are truthy, no wrapper can pass both checks and
                the result is empty.
            get_session: Forwarded to each wrapper's ``to_tool``.

        Returns:
            Tools in the order the toolset yields its wrappers.
        """
        tools: list[BaseTool] = []
        for wrapper in toolset.tool_wrappers():
            if only_global and not wrapper.is_global():
                continue
            if only_per_session and not wrapper.is_per_session():
                continue
            tools.append(wrapper.to_tool(get_session=get_session))
        return tools

    @staticmethod
    def create_tools(
        tool_creator: ToolCreator,
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Delegate to ``tool_creator.create_tools``.

        ``get_session`` is passed as the second positional argument, so the
        creator must accept it.
        """
        return tool_creator.create_tools(toolsets, get_session)

    @staticmethod
    def create_global_tools_from_names(
        toolnames: list[str], llm: Optional[BaseLLM] = None
    ) -> list[BaseTool]:
        """Load LangChain tools by name via ``load_tools``.

        ``llm`` defaults to None (its annotation was already Optional) and is
        forwarded as ``load_tools(..., llm=llm)``.
        """
        return load_tools(toolnames, llm=llm)
|