Spaces:
Sleeping
Sleeping
| import json | |
| import asyncio | |
| from typing import AsyncGenerator, Dict, Any, List, Optional | |
| from llama_index.core.callbacks.base import BaseCallbackHandler | |
| from llama_index.core.callbacks.schema import CBEventType | |
| from llama_index.core.tools.types import ToolOutput | |
| from pydantic import BaseModel | |
| class CallbackEvent(BaseModel): | |
| event_type: CBEventType | |
| payload: Optional[Dict[str, Any]] = None | |
| event_id: str = "" | |
| def get_retrieval_message(self) -> dict | None: | |
| if self.payload: | |
| nodes = self.payload.get("nodes") | |
| if nodes: | |
| msg = f"Retrieved {len(nodes)} sources to use as context for the query" | |
| else: | |
| msg = f"Retrieving context for query: '{self.payload.get('query_str')}'" | |
| return { | |
| "type": "events", | |
| "data": {"title": msg}, | |
| } | |
| else: | |
| return None | |
| def get_tool_message(self) -> dict | None: | |
| func_call_args = self.payload.get("function_call") | |
| if func_call_args is not None and "tool" in self.payload: | |
| tool = self.payload.get("tool") | |
| return { | |
| "type": "events", | |
| "data": { | |
| "title": f"Calling tool: {tool.name} with inputs: {func_call_args}", | |
| }, | |
| } | |
| def _is_output_serializable(self, output: Any) -> bool: | |
| try: | |
| json.dumps(output) | |
| return True | |
| except TypeError: | |
| return False | |
| def get_agent_tool_response(self) -> dict | None: | |
| response = self.payload.get("response") | |
| if response is not None: | |
| sources = response.sources | |
| for source in sources: | |
| # Return the tool response here to include the toolCall information | |
| if isinstance(source, ToolOutput): | |
| if self._is_output_serializable(source.raw_output): | |
| output = source.raw_output | |
| else: | |
| output = source.content | |
| return { | |
| "type": "tools", | |
| "data": { | |
| "toolOutput": { | |
| "output": output, | |
| "isError": source.is_error, | |
| }, | |
| "toolCall": { | |
| "id": None, # There is no tool id in the ToolOutput | |
| "name": source.tool_name, | |
| "input": source.raw_input, | |
| }, | |
| }, | |
| } | |
| def to_response(self): | |
| match self.event_type: | |
| case "retrieve": | |
| return self.get_retrieval_message() | |
| case "function_call": | |
| return self.get_tool_message() | |
| case "agent_step": | |
| return self.get_agent_tool_response() | |
| case _: | |
| return None | |
class EventCallbackHandler(BaseCallbackHandler):
    """Callback handler that queues relevant LlamaIndex events for async streaming.

    Every start/end callback is wrapped in a ``CallbackEvent``. Events whose
    ``to_response()`` is not ``None`` are put on an internal ``asyncio.Queue``,
    which ``async_event_gen`` drains.
    """

    # Relevant CallbackEvents waiting to be yielded by async_event_gen.
    _aqueue: asyncio.Queue
    # Stop flag read by async_event_gen; it is never set inside this class.
    # NOTE(review): presumably the caller sets it once the query/agent run has
    # finished — confirm at the call site.
    is_done: bool = False

    def __init__(
        self,
    ):
        """Create the handler and its event queue.

        Chunking, node-parsing, embedding, LLM and templating events are
        ignored for both event starts and event ends.
        """
        ignored_events = [
            CBEventType.CHUNKING,
            CBEventType.NODE_PARSING,
            CBEventType.EMBEDDING,
            CBEventType.LLM,
            CBEventType.TEMPLATING,
        ]
        # Same list for ignored starts and ignored ends.
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def _enqueue_if_relevant(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> None:
        """Wrap a callback in a CallbackEvent and queue it if it yields a message."""
        event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Queue the event start if it is relevant and return its ``event_id``."""
        self._enqueue_if_relevant(event_type, payload, event_id)
        # Honour the ``-> str`` contract by returning the event's id.
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Queue the event end if it is relevant."""
        self._enqueue_if_relevant(event_type, payload, event_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """No-op."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """No-op."""

    async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
        """Yield queued events until ``is_done`` is set and the queue is empty.

        The queue is polled with a 0.1 s timeout so the loop periodically
        re-checks ``is_done`` instead of blocking forever on an empty queue.
        """
        while not self._aqueue.empty() or not self.is_done:
            try:
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                # Nothing arrived within the poll window; loop and re-check.
                pass