Zack Zitting Bradshaw
Upload folder using huggingface_hub
4962437
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Optional, Type, Tuple

from langchain.agents import load_tools
from langchain.agents.agent import AgentExecutor
from langchain.llms.base import BaseLLM
from pydantic import BaseModel
class ToolScope(Enum):
    """Scope of a tool: shared globally or tied to a single session."""

    # Tool that is built without any session context.
    GLOBAL = "global"
    # Tool that is handed a session getter when it is built.
    SESSION = "session"
class ToolException(Exception):
    """Error type dedicated to failures raised while running a tool."""
class BaseTool(ABC):
    """Abstract interface shared by every tool.

    Concrete tools implement a synchronous ``run`` and an asynchronous
    ``arun``. Calling the instance directly is shorthand for ``run``.
    """

    # Identifier of the tool.
    name: str
    # Human-readable summary of the tool.
    description: str

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool synchronously and return its result."""

    @abstractmethod
    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool asynchronously and return its result."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward the call to ``run`` with the same arguments."""
        outcome = self.run(*args, **kwargs)
        return outcome
class Tool(BaseTool):
    """Tool backed by an arbitrary callable.

    Args:
        name: Identifier of the tool.
        description: Human-readable summary of the tool.
        func: Callable executed by ``run`` and ``arun``. It may be a regular
            function or a coroutine function.
    """

    def __init__(self, name: str, description: str, func: Callable[..., Any]):
        self.name = name
        self.description = description
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result.

        Any exception raised by ``func`` (``ToolException`` included)
        propagates to the caller unchanged.
        """
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` and return its result, awaiting it when necessary.

        Any exception raised by ``func`` (``ToolException`` included)
        propagates to the caller unchanged.
        """
        result = self.func(*args, **kwargs)
        # A synchronous ``func`` returns a plain value; awaiting that would
        # raise TypeError after the work has already been done, so only
        # await results that are actually awaitable.
        if inspect.isawaitable(result):
            return await result
        return result
class StructuredTool(BaseTool):
    """Tool backed by a callable and annotated with an argument schema.

    Args:
        name: Identifier of the tool.
        description: Human-readable summary of the tool.
        args_schema: Pydantic model class describing the tool's arguments.
            It is stored on the instance; this class does not validate
            call arguments against it.
        func: Callable executed by ``run`` and ``arun``. It may be a regular
            function or a coroutine function.
    """

    def __init__(
        self,
        name: str,
        description: str,
        args_schema: Type[BaseModel],
        func: Callable[..., Any],
    ):
        self.name = name
        self.description = description
        self.args_schema = args_schema
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result.

        Any exception raised by ``func`` (``ToolException`` included)
        propagates to the caller unchanged.
        """
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` and return its result, awaiting it when necessary.

        Any exception raised by ``func`` (``ToolException`` included)
        propagates to the caller unchanged.
        """
        result = self.func(*args, **kwargs)
        # A synchronous ``func`` returns a plain value; awaiting that would
        # raise TypeError after the work has already been done, so only
        # await results that are actually awaitable.
        if inspect.isawaitable(result):
            return await result
        return result
# Zero-argument callable returning a ``(str, AgentExecutor)`` pair.
# NOTE(review): the string is presumably a session identifier -- confirm
# against the code that supplies the getter.
SessionGetter = Callable[[], Tuple[str, AgentExecutor]]
class ToolWrapper:
    """Description of a tool callable together with its scope.

    Args:
        name: Identifier given to the resulting tool.
        description: Human-readable summary given to the resulting tool.
        scope: Whether the tool is global or per-session.
        func: Callable that implements the tool.
    """

    def __init__(
        self,
        name: str,
        description: str,
        scope: ToolScope,
        func: Callable[..., Any],
    ):
        self.name = name
        self.description = description
        self.scope = scope
        self.func = func

    def is_global(self) -> bool:
        """Return True when the wrapped tool has global scope."""
        return self.scope == ToolScope.GLOBAL

    def is_per_session(self) -> bool:
        """Return True when the wrapped tool has session scope."""
        return self.scope == ToolScope.SESSION

    def to_tool(self, get_session: SessionGetter = lambda: []) -> BaseTool:
        """Build a ``Tool`` from this wrapper.

        For session-scoped wrappers the resulting tool calls the wrapped
        callable with an extra ``get_session`` keyword argument; global
        wrappers are passed through untouched.

        Args:
            get_session: Callable injected into session-scoped tools.

        Returns:
            A ``Tool`` carrying this wrapper's name and description.
        """
        func = self.func
        if self.is_per_session():
            # Capture the original callable in a local. Rebinding
            # ``self.func`` to a closure that looks up ``self.func`` again
            # would make the closure call itself and recurse forever, and
            # would also stack wrappers on repeated ``to_tool`` calls.
            wrapped = self.func

            def func(*args: Any, **kwargs: Any) -> Any:
                return wrapped(*args, **kwargs, get_session=get_session)

        return Tool(name=self.name, description=self.description, func=func)
class BaseToolSet:
    """Base class for objects whose attributes can be exposed as tools."""

    def tool_wrappers(cls) -> list[ToolWrapper]:
        """Return a ``ToolWrapper`` for every attribute marked ``is_tool``.

        The first parameter is named ``cls`` but the method is not a
        classmethod: it receives whatever object it is called on.
        """
        marked = []
        for attr_name in dir(cls):
            if hasattr(getattr(cls, attr_name), "is_tool"):
                marked.append(getattr(cls, attr_name))
        return [
            ToolWrapper(member.name, member.description, member.scope, member)
            for member in marked
        ]
class ToolCreator(ABC):
    """Abstract strategy that turns toolsets into a list of tools."""

    @abstractmethod
    def create_tools(
        self,
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Build tools from the given toolsets.

        Args:
            toolsets: Toolsets to extract tools from.
            get_session: Session getter handed to session-scoped tools.
                ``ToolsFactory.create_tools`` always passes it, so every
                implementation must accept it even when it has no use
                for it.

        Returns:
            The tools created from ``toolsets``.
        """
class GlobalToolsCreator(ToolCreator):
    """Creator that keeps only the globally scoped tools of each toolset."""

    def create_tools(
        self,
        toolsets: list[BaseToolSet],
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Build the global tools of every toolset.

        Args:
            toolsets: Toolsets to extract tools from.
            get_session: Accepted so this creator can be invoked through
                ``ToolsFactory.create_tools``, which always passes a
                session getter. Global tools do not use it, so it is
                ignored.

        Returns:
            All globally scoped tools, in toolset order.
        """
        tools = []
        for toolset in toolsets:
            tools.extend(
                ToolsFactory.from_toolset(
                    toolset=toolset,
                    only_global=True,
                )
            )
        return tools
class SessionToolsCreator(ToolCreator):
    """Creator that keeps only the session-scoped tools of each toolset."""

    def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Build the per-session tools of every toolset.

        Args:
            toolsets: Toolsets to extract tools from.
            get_session: Session getter forwarded to each created tool.

        Returns:
            All session-scoped tools, in toolset order.
        """
        return [
            session_tool
            for toolset in toolsets
            for session_tool in ToolsFactory.from_toolset(
                toolset=toolset,
                only_per_session=True,
                get_session=get_session,
            )
        ]
class ToolsFactory:
    """Static helpers for turning toolsets or tool names into tools."""

    @staticmethod
    def from_toolset(toolset: BaseToolSet, only_global: Optional[bool] = False, only_per_session: Optional[bool] = False, get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Convert the wrappers of one toolset into tools.

        Args:
            toolset: Object whose ``tool_wrappers()`` supplies the wrappers.
            only_global: When truthy, skip wrappers that are not global.
            only_per_session: When truthy, skip wrappers that are not
                per-session.
            get_session: Session getter passed to every ``to_tool`` call.

        Returns:
            The tools built from the wrappers that passed both filters.
        """
        return [
            candidate.to_tool(get_session=get_session)
            for candidate in toolset.tool_wrappers()
            if (not only_global or candidate.is_global())
            and (not only_per_session or candidate.is_per_session())
        ]

    @staticmethod
    def create_tools(tool_creator: ToolCreator, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []):
        """Delegate tool creation to ``tool_creator``.

        ``toolsets`` and ``get_session`` are both passed positionally.
        """
        created = tool_creator.create_tools(toolsets, get_session)
        return created

    @staticmethod
    def create_global_tools_from_names(toolnames: list[str], llm: Optional[BaseLLM]) -> list[BaseTool]:
        """Load tools by name through ``load_tools``, forwarding ``llm``."""
        loaded = load_tools(toolnames, llm=llm)
        return loaded