Spaces:
Sleeping
Sleeping
from typing import Any, Dict, Optional

from langchain.callbacks.base import BaseCallbackHandler

from ..tools.cache_tools import CacheTools
from .cache.cache_handler import CacheHandler
class ToolsHandler(BaseCallbackHandler):
    """Callback handler that tracks tool usage and caches tool results.

    ``on_tool_start`` remembers the name and input of the tool being run;
    ``on_tool_end`` then hands that name/input together with the tool's
    output to ``self.cache.add(...)``, unless the output looks like an error
    message or the tool is the cache tool itself.
    """

    # Class-level defaults kept for backward compatibility; ``__init__``
    # gives every instance its own values so state is not shared.
    last_used_tool: Dict[str, Any] = {}
    cache: Optional[CacheHandler] = None

    def __init__(self, cache: Optional[CacheHandler] = None, **kwargs: Any):
        """Initialize the callback handler.

        Args:
            cache: Object whose ``add(tool=..., input=..., output=...)``
                method is called for each successful tool run. When ``None``,
                results are not cached.
            **kwargs: Forwarded unchanged to ``BaseCallbackHandler``.
        """
        self.cache = cache
        # Per-instance dict so handlers never share the class-level default.
        self.last_used_tool = {}
        super().__init__(**kwargs)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Run when tool starts running.

        Records the tool name (``serialized["name"]``) and ``input_str`` in
        ``self.last_used_tool``. The pseudo-tools ``"invalid_tool"`` and
        ``"_Exception"`` are ignored, leaving the previous record in place.
        """
        name = serialized.get("name")
        if name not in ["invalid_tool", "_Exception"]:
            self.last_used_tool = {
                "tool": name,
                "input": input_str,
            }

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running.

        Adds the last recorded tool call and its ``output`` to the cache.
        Nothing is cached when the output contains one of the known error
        markers, when no cache was configured, when no tool call has been
        recorded yet, or when the recorded tool is the cache tool itself.
        """
        error_markers = (
            "is not a valid tool",
            "Invalid or incomplete response",
            "Invalid Format",
        )
        if any(marker in output for marker in error_markers):
            return

        # Without a cache or a recorded tool call there is nothing to store;
        # previously this path raised AttributeError / KeyError.
        if self.cache is None or not self.last_used_tool:
            return

        # Do not cache reads that were themselves served by the cache tool.
        if self.last_used_tool["tool"] != CacheTools().name:
            self.cache.add(
                tool=self.last_used_tool["tool"],
                input=self.last_used_tool["input"],
                output=output,
            )