eaglelandsonce's picture
Upload 17 files
c631b4c
raw
history blame
1.46 kB
from typing import Any, Dict
from langchain.callbacks.base import BaseCallbackHandler
from ..tools.cache_tools import CacheTools
from .cache.cache_handler import CacheHandler
class ToolsHandler(BaseCallbackHandler):
"""Callback handler for tool usage."""
last_used_tool: Dict[str, Any] = {}
cache: CacheHandler = None
def __init__(self, cache: CacheHandler = None, **kwargs: Any):
"""Initialize the callback handler."""
self.cache = cache
super().__init__(**kwargs)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
"""Run when tool starts running."""
name = serialized.get("name")
if name not in ["invalid_tool", "_Exception"]:
tools_usage = {
"tool": name,
"input": input_str,
}
self.last_used_tool = tools_usage
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
if (
"is not a valid tool" not in output
and "Invalid or incomplete response" not in output
and "Invalid Format" not in output
):
if self.last_used_tool["tool"] != CacheTools().name:
self.cache.add(
tool=self.last_used_tool["tool"],
input=self.last_used_tool["input"],
output=output,
)