import json import asyncio from typing import AsyncGenerator, Dict, Any, List, Optional from llama_index.core.callbacks.base import BaseCallbackHandler from llama_index.core.callbacks.schema import CBEventType from llama_index.core.tools.types import ToolOutput from pydantic import BaseModel class CallbackEvent(BaseModel): event_type: CBEventType payload: Optional[Dict[str, Any]] = None event_id: str = "" def get_retrieval_message(self) -> dict | None: if self.payload: nodes = self.payload.get("nodes") if nodes: msg = f"Retrieved {len(nodes)} sources to use as context for the query" else: msg = f"Retrieving context for query: '{self.payload.get('query_str')}'" return { "type": "events", "data": {"title": msg}, } else: return None def get_tool_message(self) -> dict | None: func_call_args = self.payload.get("function_call") if func_call_args is not None and "tool" in self.payload: tool = self.payload.get("tool") return { "type": "events", "data": { "title": f"Calling tool: {tool.name} with inputs: {func_call_args}", }, } def _is_output_serializable(self, output: Any) -> bool: try: json.dumps(output) return True except TypeError: return False def get_agent_tool_response(self) -> dict | None: response = self.payload.get("response") if response is not None: sources = response.sources for source in sources: # Return the tool response here to include the toolCall information if isinstance(source, ToolOutput): if self._is_output_serializable(source.raw_output): output = source.raw_output else: output = source.content return { "type": "tools", "data": { "toolOutput": { "output": output, "isError": source.is_error, }, "toolCall": { "id": None, # There is no tool id in the ToolOutput "name": source.tool_name, "input": source.raw_input, }, }, } def to_response(self): match self.event_type: case "retrieve": return self.get_retrieval_message() case "function_call": return self.get_tool_message() case "agent_step": return self.get_agent_tool_response() case _: return None class EventCallbackHandler(BaseCallbackHandler): _aqueue: asyncio.Queue is_done: bool = False def __init__( self, ): """Initialize the base callback handler.""" ignored_events = [ CBEventType.CHUNKING, CBEventType.NODE_PARSING, CBEventType.EMBEDDING, CBEventType.LLM, CBEventType.TEMPLATING, ] super().__init__(ignored_events, ignored_events) self._aqueue = asyncio.Queue() def on_event_start( self, event_type: CBEventType, payload: Optional[Dict[str, Any]] = None, event_id: str = "", **kwargs: Any, ) -> str: event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload) if event.to_response() is not None: self._aqueue.put_nowait(event) def on_event_end( self, event_type: CBEventType, payload: Optional[Dict[str, Any]] = None, event_id: str = "", **kwargs: Any, ) -> None: event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload) if event.to_response() is not None: self._aqueue.put_nowait(event) def start_trace(self, trace_id: Optional[str] = None) -> None: """No-op.""" def end_trace( self, trace_id: Optional[str] = None, trace_map: Optional[Dict[str, List[str]]] = None, ) -> None: """No-op.""" async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]: while not self._aqueue.empty() or not self.is_done: try: yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1) except asyncio.TimeoutError: pass