File size: 4,823 Bytes
4962437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Optional, Type, Tuple

from langchain.agents import load_tools
from langchain.agents.agent import AgentExecutor
from langchain.llms.base import BaseLLM
from pydantic import BaseModel

class ToolScope(Enum):
    """Scope tag attached to a tool wrapper.

    The scope is compared against in ``ToolWrapper.is_global`` and
    ``ToolWrapper.is_per_session`` to decide how a wrapper is filtered and
    converted into a tool.
    """

    # Selected when tools are built with ``only_global=True``.
    GLOBAL = "global"
    # Selected when tools are built with ``only_per_session=True``.
    SESSION = "session"


class ToolException(Exception):
    """Exception type dedicated to this module's tools.

    NOTE(review): presumably raised by tool callables to signal a failure;
    confirm against the callers that raise it.
    """


class BaseTool(ABC):
    """Abstract interface shared by every tool in this module.

    Subclasses must provide both a synchronous ``run`` and an asynchronous
    ``arun``. Calling an instance directly is shorthand for ``run``.
    """

    # Declared here, assigned by concrete subclasses.
    name: str
    description: str

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward the call to ``run`` and hand back whatever it returns."""
        outcome = self.run(*args, **kwargs)
        return outcome

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool synchronously."""

    @abstractmethod
    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool asynchronously."""


class Tool(BaseTool):
    """Tool that delegates to an arbitrary callable.

    Args:
        name: Identifier stored on the tool.
        description: Human-readable description stored on the tool.
        func: Callable invoked by ``run`` and ``arun``. It may be a plain
            function or a coroutine function.
    """

    def __init__(self, name: str, description: str, func: Callable[..., Any]):
        self.name = name
        self.description = description
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result.

        Any exception raised by ``func`` (``ToolException`` included)
        propagates to the caller unchanged. If ``func`` is a coroutine
        function the un-awaited coroutine is returned as-is.
        """
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` and await its result only when it is awaitable.

        Awaiting unconditionally would raise ``TypeError`` for every
        synchronous ``func`` (after it had already executed), so the result
        is inspected first. Exceptions raised by ``func`` propagate unchanged.
        """
        result = self.func(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result


class StructuredTool(BaseTool):
    """Tool that delegates to a callable and carries an argument schema.

    Args:
        name: Identifier stored on the tool.
        description: Human-readable description stored on the tool.
        args_schema: Pydantic model class describing the arguments. It is
            stored on the instance; ``run`` and ``arun`` do not consult it.
        func: Callable invoked by ``run`` and ``arun``. It may be a plain
            function or a coroutine function.
    """

    def __init__(
        self,
        name: str,
        description: str,
        args_schema: Type[BaseModel],
        func: Callable[..., Any]
    ):
        self.name = name
        self.description = description
        self.args_schema = args_schema
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result.

        Any exception raised by ``func`` (``ToolException`` included)
        propagates to the caller unchanged. If ``func`` is a coroutine
        function the un-awaited coroutine is returned as-is.
        """
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` and await its result only when it is awaitable.

        Awaiting unconditionally would raise ``TypeError`` for every
        synchronous ``func`` (after it had already executed), so the result
        is inspected first. Exceptions raised by ``func`` propagate unchanged.
        """
        result = self.func(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result


# Zero-argument callable that returns a ``(str, AgentExecutor)`` pair.
# NOTE(review): the str is presumably a session identifier — confirm against
# the callers that supply a getter. Several parameter defaults in this file
# use ``lambda: []``, which does not satisfy this alias.
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]


class ToolWrapper:
    """Pairs a callable with the metadata needed to expose it as a ``Tool``.

    Args:
        name: Name given to the resulting tool.
        description: Description given to the resulting tool.
        scope: ``ToolScope`` member that decides how the callable is wrapped.
        func: The callable the resulting tool will invoke.
    """

    def __init__(self, name: str, description: str, scope: ToolScope, func: Callable[..., Any]):
        self.name = name
        self.description = description
        self.scope = scope
        self.func = func

    def is_global(self) -> bool:
        """Return True when the scope is ``ToolScope.GLOBAL``."""
        return self.scope == ToolScope.GLOBAL

    def is_per_session(self) -> bool:
        """Return True when the scope is ``ToolScope.SESSION``."""
        return self.scope == ToolScope.SESSION

    def to_tool(self, get_session: SessionGetter = lambda: []) -> BaseTool:
        """Build a ``Tool`` from this wrapper.

        Global wrappers expose ``func`` directly. Session wrappers expose a
        closure that forwards every call to ``func`` with an additional
        ``get_session=get_session`` keyword argument.

        Args:
            get_session: Callable handed to session-scoped functions.

        Returns:
            A ``Tool`` carrying this wrapper's name and description.
        """
        # Capture the callable in a local. Rebinding ``self.func`` to a
        # closure that reads ``self.func`` at call time would make the closure
        # call itself (RecursionError) and would mutate the wrapper on every
        # conversion.
        target = self.func
        if not self.is_per_session():
            return Tool(name=self.name, description=self.description, func=target)

        def call_with_session(*args: Any, **kwargs: Any) -> Any:
            return target(*args, **kwargs, get_session=get_session)

        return Tool(name=self.name, description=self.description, func=call_with_session)


class BaseToolSet:
    """Base class for objects whose marked attributes are exposed as tools."""

    def tool_wrappers(cls) -> list[ToolWrapper]:
        """Collect a ``ToolWrapper`` for every attribute carrying ``is_tool``.

        The receiver is named ``cls`` but this is a regular method: it is
        called on toolset instances, so the collected attributes are looked
        up on that instance. Each marked attribute must also expose ``name``,
        ``description`` and ``scope``.
        """
        wrappers = []
        for attr_name in dir(cls):
            member = getattr(cls, attr_name)
            if not hasattr(member, "is_tool"):
                continue
            wrappers.append(
                ToolWrapper(member.name, member.description, member.scope, member)
            )
        return wrappers


class ToolCreator(ABC):
    """Interface for strategies that turn toolsets into a list of tools."""

    @abstractmethod
    def create_tools(
        self,
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Build tools from ``toolsets``.

        Args:
            toolsets: Toolsets to draw tools from.
            get_session: Session getter made available to the creator. It is
                part of the abstract signature because
                ``ToolsFactory.create_tools`` passes it positionally to every
                creator; implementations that do not need it may ignore it.

        Returns:
            The tools produced by the implementation.
        """


class GlobalToolsCreator(ToolCreator):
    """Creator that keeps only the globally scoped tools of each toolset."""

    def create_tools(
        self,
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Build the global tools of every toolset, in toolset order.

        Args:
            toolsets: Toolsets to draw tools from.
            get_session: Accepted for signature compatibility and ignored.
                ``ToolsFactory.create_tools`` passes a session getter
                positionally to whichever creator it receives; without this
                parameter that call raised ``TypeError`` for this creator.

        Returns:
            The tools whose wrappers report a global scope.
        """
        tools = []
        for toolset in toolsets:
            tools.extend(
                ToolsFactory.from_toolset(
                    toolset=toolset,
                    only_global=True,
                )
            )
        return tools


class SessionToolsCreator(ToolCreator):
    """Creator that keeps only the session-scoped tools of each toolset."""

    def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Build the per-session tools of every toolset, in toolset order.

        ``get_session`` is forwarded to ``ToolsFactory.from_toolset`` so the
        produced tools can reach it.
        """
        return [
            session_tool
            for current in toolsets
            for session_tool in ToolsFactory.from_toolset(
                toolset=current,
                only_per_session=True,
                get_session=get_session,
            )
        ]


class ToolsFactory:
    """Static helpers that assemble tools from toolsets, creators or names."""

    @staticmethod
    def from_toolset(toolset: BaseToolSet, only_global: Optional[bool] = False, only_per_session: Optional[bool] = False, get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Convert the wrappers of ``toolset`` into tools, optionally filtered.

        Args:
            toolset: Object whose ``tool_wrappers()`` supplies the wrappers.
            only_global: When truthy, keep only wrappers reporting a global scope.
            only_per_session: When truthy, keep only wrappers reporting a
                per-session scope.
            get_session: Passed to each kept wrapper's ``to_tool``.

        Returns:
            The tools for the kept wrappers, in wrapper order.
        """
        return [
            candidate.to_tool(get_session=get_session)
            for candidate in toolset.tool_wrappers()
            if (not only_global or candidate.is_global())
            and (not only_per_session or candidate.is_per_session())
        ]

    @staticmethod
    def create_tools(tool_creator: ToolCreator, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []):
        """Ask ``tool_creator`` to build tools, passing ``get_session`` positionally."""
        created = tool_creator.create_tools(toolsets, get_session)
        return created

    @staticmethod
    def create_global_tools_from_names(toolnames: list[str], llm: Optional[BaseLLM]) -> list[BaseTool]:
        """Load tools by name through langchain's ``load_tools``.

        NOTE(review): the returned objects come from langchain, while the
        annotation names this module's ``BaseTool`` — confirm which is meant.
        """
        loaded = load_tools(toolnames, llm=llm)
        return loaded