# Spaces: Runtime error
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Optional, Type, Tuple

from langchain.agents import load_tools
from langchain.agents.agent import AgentExecutor
from langchain.llms.base import BaseLLM
from pydantic import BaseModel
class ToolScope(Enum):
    """Lifetime scope of a tool exposed by a toolset."""

    # Built without any session binding.
    GLOBAL = "global"
    # Receives a ``get_session`` keyword argument when invoked.
    SESSION = "session"
class ToolException(Exception):
    """Error type for failures raised from inside a tool's function."""
class BaseTool(ABC):
    """Minimal callable tool interface.

    Subclasses override ``run`` (and ``arun`` for async use). Calling the
    instance directly delegates to ``run``. The base implementations do
    nothing and return ``None``.
    """

    name: str  # identifier the tool is exposed under
    description: str  # human-readable description of the tool

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool synchronously; the base version returns None."""
        return None

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the tool asynchronously; the base version returns None."""
        return None

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke ``run`` with the given arguments."""
        runner = self.run
        return runner(*args, **kwargs)
class Tool(BaseTool):
    """Tool backed by a plain callable.

    ``run`` calls ``func`` directly. ``arun`` calls ``func`` and awaits the
    result only when it is awaitable, so both sync and async callables
    work. Exceptions, including ToolException, propagate unchanged.
    """

    def __init__(self, name: str, description: str, func: Callable[..., Any]):
        """
        Args:
            name: Name the tool is exposed under.
            description: Human-readable description of the tool.
            func: Callable that implements the tool; may be sync or async.
        """
        self.name = name
        self.description = description
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result."""
        # The former `except ToolException as e: raise e` only re-raised; dropped.
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func``, awaiting the result if it is awaitable."""
        result = self.func(*args, **kwargs)
        # Awaiting a plain (non-awaitable) value raises TypeError, so only
        # await when func actually returned a coroutine/awaitable.
        if inspect.isawaitable(result):
            result = await result
        return result
class StructuredTool(BaseTool):
    """Tool backed by a callable plus a pydantic model describing its arguments.

    ``args_schema`` is stored but not used for validation in this class.
    ``arun`` awaits the result of ``func`` only when it is awaitable, so
    both sync and async callables work. Exceptions, including
    ToolException, propagate unchanged.
    """

    def __init__(
        self,
        name: str,
        description: str,
        args_schema: Type[BaseModel],
        func: Callable[..., Any]
    ):
        """
        Args:
            name: Name the tool is exposed under.
            description: Human-readable description of the tool.
            args_schema: pydantic model class describing the tool's arguments.
            func: Callable that implements the tool; may be sync or async.
        """
        self.name = name
        self.description = description
        self.args_schema = args_schema
        self.func = func

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with the given arguments and return its result."""
        # The former `except ToolException as e: raise e` only re-raised; dropped.
        return self.func(*args, **kwargs)

    async def arun(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func``, awaiting the result if it is awaitable."""
        result = self.func(*args, **kwargs)
        # Awaiting a plain value raises TypeError; only await real awaitables.
        if inspect.isawaitable(result):
            result = await result
        return result
# Zero-argument callable returning a (str, AgentExecutor) pair.
# NOTE(review): the default getters in this module are ``lambda: []``, which
# return a list rather than this tuple; confirm the intended contract.
SessionGetter = Callable[
    [],
    Tuple[str, AgentExecutor],
]
class ToolWrapper:
    """A toolset callable plus the metadata needed to expose it as a Tool.

    Per-session tools get a ``get_session`` keyword argument injected on every
    call; global tools are exposed as-is.
    """

    def __init__(self, name: str, description: str, scope: ToolScope, func: Callable[..., Any]):
        """
        Args:
            name: Name of the resulting tool.
            description: Description of the resulting tool.
            scope: Whether the tool is global or per-session.
            func: The underlying callable (a bound toolset method in BaseToolSet).
        """
        self.name = name
        self.description = description
        self.scope = scope
        self.func = func

    def is_global(self) -> bool:
        """Return True when this wrapper's scope is ToolScope.GLOBAL."""
        return self.scope == ToolScope.GLOBAL

    def is_per_session(self) -> bool:
        """Return True when this wrapper's scope is ToolScope.SESSION."""
        return self.scope == ToolScope.SESSION

    def to_tool(self, get_session: SessionGetter = lambda: []) -> BaseTool:
        """Build a Tool for this wrapper.

        Args:
            get_session: Passed as the ``get_session`` keyword to ``func`` on
                every call when the wrapper is per-session; ignored otherwise.

        Returns:
            A new Tool. The wrapper itself is not modified, so this can be
            called repeatedly (e.g. once per session).
        """
        original = self.func
        if self.is_per_session():
            # Capture the original callable in a local. The previous version
            # reassigned self.func to a lambda that called self.func, i.e.
            # itself, which recursed forever when the tool was invoked.
            def tool_func(*args: Any, **kwargs: Any) -> Any:
                return original(*args, **kwargs, get_session=get_session)
        else:
            tool_func = original
        return Tool(name=self.name, description=self.description, func=tool_func)
class BaseToolSet:
    """Base class for a collection of tool methods.

    Any attribute with an ``is_tool`` attribute is treated as a tool and must
    also provide ``name``, ``description`` and ``scope``. NOTE(review): these
    are presumably set by a decorator defined elsewhere; confirm.
    """

    def tool_wrappers(cls) -> list[ToolWrapper]:
        """Return a ToolWrapper for every tool-marked attribute.

        Despite the parameter name this is an ordinary instance method:
        callers invoke it on a toolset instance, so each wrapped callable is
        a bound method.
        """
        # Fetch each attribute exactly once: the previous version called
        # getattr twice per name, running property getters twice.
        attrs = (getattr(cls, attr_name) for attr_name in dir(cls))
        return [
            ToolWrapper(attr.name, attr.description, attr.scope, attr)
            for attr in attrs
            if hasattr(attr, "is_tool")
        ]
class ToolCreator(ABC):
    """Interface for strategies that turn toolsets into tools."""

    @abstractmethod
    def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Build tools from ``toolsets``.

        Args:
            toolsets: Toolsets to collect tools from.
            get_session: Zero-argument session getter. The module-level
                ``create_tools`` helper always passes it as the second
                positional argument, so every implementation must accept it
                (implementations that do not need it may ignore it).

        Returns:
            The created tools.
        """
class GlobalToolsCreator(ToolCreator):
    """Create only the global-scope tools from each toolset."""

    def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Collect the global tools of every toolset, in order.

        Args:
            toolsets: Toolsets to collect tools from.
            get_session: Accepted (and unused) so this creator can be driven by
                the module-level ``create_tools`` helper, which always passes
                it positionally. Previously this raised TypeError.

        Returns:
            A flat list of global tools.
        """
        return [
            tool
            for toolset in toolsets
            for tool in ToolsFactory.from_toolset(
                toolset=toolset,
                only_global=True,
            )
        ]
class SessionToolsCreator(ToolCreator):
    """Create only the per-session tools, each bound to ``get_session``."""

    def create_tools(self, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []) -> list[BaseTool]:
        """Collect the per-session tools of every toolset, in order.

        Args:
            toolsets: Toolsets to collect tools from.
            get_session: Forwarded to each per-session tool.

        Returns:
            A flat list of per-session tools.
        """
        return [
            tool
            for toolset in toolsets
            for tool in ToolsFactory.from_toolset(
                toolset=toolset,
                only_per_session=True,
                get_session=get_session,
            )
        ]
class ToolsFactory:
    """Convert a toolset's wrappers into Tool objects, optionally filtered by scope."""

    @staticmethod
    def from_toolset(
        toolset: BaseToolSet,
        only_global: Optional[bool] = False,
        only_per_session: Optional[bool] = False,
        get_session: SessionGetter = lambda: [],
    ) -> list[BaseTool]:
        """Build tools from ``toolset.tool_wrappers()``.

        Declared ``@staticmethod`` so that calling it through an instance
        does not bind the instance as ``toolset``. Calls through the class
        behave as before.

        Args:
            toolset: Object whose ``tool_wrappers()`` yields ToolWrapper-like items.
            only_global: Keep only wrappers whose ``is_global()`` is true.
            only_per_session: Keep only wrappers whose ``is_per_session()`` is true.
                If both flags are set, a wrapper must satisfy both checks.
            get_session: Passed to each wrapper's ``to_tool``.

        Returns:
            The tools, in the order the wrappers were produced.
        """
        tools: list[BaseTool] = []
        for wrapper in toolset.tool_wrappers():
            if only_global and not wrapper.is_global():
                continue
            if only_per_session and not wrapper.is_per_session():
                continue
            tools.append(wrapper.to_tool(get_session=get_session))
        return tools
def create_tools(tool_creator: ToolCreator, toolsets: list[BaseToolSet], get_session: SessionGetter = lambda: []):
    """Delegate tool construction to ``tool_creator``.

    ``get_session`` is always forwarded as the second positional argument,
    so the creator's ``create_tools`` must accept it.
    """
    build = tool_creator.create_tools
    created = build(toolsets, get_session)
    return created
def create_global_tools_from_names(toolnames: list[str], llm: Optional[BaseLLM]) -> list[BaseTool]:
    """Load tools by name through LangChain's ``load_tools``.

    Args:
        toolnames: Names understood by ``load_tools``.
        llm: Language model passed to ``load_tools`` as ``llm``.

    Returns:
        Whatever ``load_tools`` returns for these names.
    """
    loaded = load_tools(toolnames, llm=llm)
    return loaded