|
from typing import Any, Dict |
|
|
|
from langchain.callbacks.base import BaseCallbackHandler |
|
|
|
from ..tools.cache_tools import CacheTools |
|
from .cache.cache_handler import CacheHandler |
|
|
|
|
|
class ToolsHandler(BaseCallbackHandler):
    """Callback handler for tool usage.

    When a tool starts, remembers its name and raw input. When the tool
    finishes, passes the (tool, input, output) triple to ``self.cache.add``
    so the result can be looked up later.
    """

    # Tool names whose invocations are never recorded for caching.
    _IGNORED_TOOL_NAMES = ("invalid_tool", "_Exception")

    # If a tool's output contains any of these substrings it is not cached.
    # They appear to be agent error / parse-failure messages rather than
    # real tool results.
    _ERROR_OUTPUT_MARKERS = (
        "is not a valid tool",
        "Invalid or incomplete response",
        "Invalid Format",
    )

    # Most recent recorded call, shaped {"tool": <name>, "input": <input_str>}.
    # Empty when no cacheable call is pending. Re-initialised per instance in
    # __init__ so instances never share this dict.
    last_used_tool: Dict[str, Any] = {}

    # Receives successful tool results; when None, nothing is cached.
    cache: CacheHandler = None

    def __init__(self, cache: CacheHandler = None, **kwargs: Any):
        """Initialize the callback handler.

        Args:
            cache: Object whose ``add(tool=..., input=..., output=...)`` is
                called with successful tool results. ``None`` disables caching.
            **kwargs: Passed through to the base class initializer.
        """
        self.cache = cache
        # Per-instance state instead of the shared class-level default dict.
        self.last_used_tool = {}
        super().__init__(**kwargs)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Run when tool starts running.

        Records ``serialized["name"]`` and ``input_str`` so that
        ``on_tool_end`` can cache the result. For ignored tool names the
        record is cleared instead. Otherwise their output could be cached
        under the name and input of a previous, unrelated tool call.
        """
        name = serialized.get("name")
        if name in self._IGNORED_TOOL_NAMES:
            self.last_used_tool = {}
            return
        self.last_used_tool = {"tool": name, "input": input_str}

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running.

        Adds ``output`` to the cache, keyed by the tool name and input
        recorded in ``on_tool_start``. Nothing is cached when:

        * no cache was configured (``self.cache is None``);
        * no cacheable tool call is pending;
        * the output contains one of ``_ERROR_OUTPUT_MARKERS``; or
        * the tool's name equals ``CacheTools().name``. This is presumably
          the cache-lookup tool itself, so its results are not re-cached;
          confirm against ``CacheTools``.
        """
        if self.cache is None or not self.last_used_tool:
            return
        if any(marker in output for marker in self._ERROR_OUTPUT_MARKERS):
            return
        if self.last_used_tool["tool"] == CacheTools().name:
            return
        self.cache.add(
            tool=self.last_used_tool["tool"],
            input=self.last_used_tool["input"],
            output=output,
        )
|
|