jfeng1115's picture
init commit
58d33f0
"""Base implementation for tools or skills."""
from abc import abstractmethod
from typing import Any, Optional
from pydantic import BaseModel, Extra, Field, validator
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class BaseTool(BaseModel):
    """Class responsible for defining a tool or skill for an LLM.

    Subclasses supply the actual work by implementing ``_run`` (synchronous)
    and ``_arun`` (asynchronous). The public ``run`` / ``arun`` wrappers call
    those implementations and report start, end and error events to
    ``callback_manager``.
    """

    # ``name`` and ``description`` are sent to ``on_tool_start`` on every run.
    name: str
    description: str
    # Not read by any method of this class; presumably consumed by whatever
    # drives the tool -- confirm against callers.
    return_direct: bool = False
    # Fallback verbosity used when ``run`` / ``arun`` receive ``verbose=None``.
    verbose: bool = False
    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @validator("callback_manager", pre=True, always=True)
    def set_callback_manager(
        cls, callback_manager: Optional[BaseCallbackManager]
    ) -> BaseCallbackManager:
        """If callback manager is None, set it.

        This allows users to pass in None as callback manager, which is a nice UX.

        Args:
            callback_manager: The value supplied for the field, possibly None.

        Returns:
            The supplied manager when it is truthy, otherwise the result of
            ``get_callback_manager()``.
        """
        return callback_manager or get_callback_manager()

    @abstractmethod
    def _run(self, tool_input: str) -> str:
        """Use the tool."""

    @abstractmethod
    async def _arun(self, tool_input: str) -> str:
        """Use the tool asynchronously."""

    def __call__(self, tool_input: str) -> str:
        """Make tools callable with str input."""
        return self.run(tool_input)

    def run(
        self,
        tool_input: str,
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        **kwargs: Any
    ) -> str:
        """Run the tool.

        Args:
            tool_input: String handed to ``_run`` and reported to
                ``on_tool_start``.
            verbose: Forwarded to every callback; ``self.verbose`` is used
                when this is None.
            start_color: Passed as ``color`` to ``on_tool_start``.
            color: Passed as ``color`` to ``on_tool_end``.
            **kwargs: Forwarded to ``on_tool_start`` and ``on_tool_end``
                (not to ``on_tool_error``).

        Returns:
            The observation returned by ``_run``.

        Raises:
            Exception, KeyboardInterrupt: Whatever ``_run`` raises, re-raised
                after ``on_tool_error`` has been called with it.
        """
        if verbose is None:
            verbose = self.verbose
        self.callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input,
            verbose=verbose,
            color=start_color,
            **kwargs,
        )
        try:
            observation = self._run(tool_input)
        except (Exception, KeyboardInterrupt) as e:
            # KeyboardInterrupt is caught explicitly (it is not an Exception
            # subclass) so the error callback also fires on Ctrl-C.
            self.callback_manager.on_tool_error(e, verbose=verbose)
            # Bare ``raise`` re-raises the exception currently being handled.
            raise
        self.callback_manager.on_tool_end(
            observation, verbose=verbose, color=color, **kwargs
        )
        return observation

    async def _dispatch_callback(self, emit: Any) -> None:
        """Invoke a zero-argument callback thunk, awaiting it if required.

        Args:
            emit: Zero-argument callable that calls one hook on
                ``self.callback_manager``. Its result is awaited only when
                ``self.callback_manager.is_async`` is truthy.
        """
        # ``is_async`` is read before the hook is invoked, matching the
        # order of the inline branches this helper replaces.
        if self.callback_manager.is_async:
            await emit()
        else:
            emit()

    async def arun(
        self,
        tool_input: str,
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        **kwargs: Any
    ) -> str:
        """Run the tool asynchronously.

        Callback hooks are awaited when ``callback_manager.is_async`` is
        truthy and called synchronously otherwise.

        Args:
            tool_input: String handed to ``_arun`` and reported to
                ``on_tool_start``.
            verbose: Forwarded to every callback; ``self.verbose`` is used
                when this is None.
            start_color: Passed as ``color`` to ``on_tool_start``.
            color: Passed as ``color`` to ``on_tool_end``.
            **kwargs: Forwarded to ``on_tool_start`` and ``on_tool_end``
                (not to ``on_tool_error``).

        Returns:
            The observation returned by ``_arun``.

        Raises:
            Exception, KeyboardInterrupt: Whatever ``_arun`` raises, re-raised
                after ``on_tool_error`` has been called with it.
        """
        if verbose is None:
            verbose = self.verbose
        await self._dispatch_callback(
            lambda: self.callback_manager.on_tool_start(
                {"name": self.name, "description": self.description},
                tool_input,
                verbose=verbose,
                color=start_color,
                **kwargs,
            )
        )
        try:
            # We then call the tool on the tool input to get an observation
            observation = await self._arun(tool_input)
        except (Exception, KeyboardInterrupt) as e:
            # The thunk runs inside this handler, while ``e`` is still bound.
            await self._dispatch_callback(
                lambda: self.callback_manager.on_tool_error(e, verbose=verbose)
            )
            # Bare ``raise`` re-raises the exception currently being handled.
            raise
        await self._dispatch_callback(
            lambda: self.callback_manager.on_tool_end(
                observation, verbose=verbose, color=color, **kwargs
            )
        )
        return observation