Spaces:
Sleeping
Sleeping
import json | |
import asyncio | |
from typing import AsyncGenerator, Dict, Any, List, Optional | |
from llama_index.core.callbacks.base import BaseCallbackHandler | |
from llama_index.core.callbacks.schema import CBEventType | |
from llama_index.core.tools.types import ToolOutput | |
from pydantic import BaseModel | |
class CallbackEvent(BaseModel): | |
event_type: CBEventType | |
payload: Optional[Dict[str, Any]] = None | |
event_id: str = "" | |
def get_retrieval_message(self) -> dict | None: | |
if self.payload: | |
nodes = self.payload.get("nodes") | |
if nodes: | |
msg = f"Retrieved {len(nodes)} sources to use as context for the query" | |
else: | |
msg = f"Retrieving context for query: '{self.payload.get('query_str')}'" | |
return { | |
"type": "events", | |
"data": {"title": msg}, | |
} | |
else: | |
return None | |
def get_tool_message(self) -> dict | None: | |
func_call_args = self.payload.get("function_call") | |
if func_call_args is not None and "tool" in self.payload: | |
tool = self.payload.get("tool") | |
return { | |
"type": "events", | |
"data": { | |
"title": f"Calling tool: {tool.name} with inputs: {func_call_args}", | |
}, | |
} | |
def _is_output_serializable(self, output: Any) -> bool: | |
try: | |
json.dumps(output) | |
return True | |
except TypeError: | |
return False | |
def get_agent_tool_response(self) -> dict | None: | |
response = self.payload.get("response") | |
if response is not None: | |
sources = response.sources | |
for source in sources: | |
# Return the tool response here to include the toolCall information | |
if isinstance(source, ToolOutput): | |
if self._is_output_serializable(source.raw_output): | |
output = source.raw_output | |
else: | |
output = source.content | |
return { | |
"type": "tools", | |
"data": { | |
"toolOutput": { | |
"output": output, | |
"isError": source.is_error, | |
}, | |
"toolCall": { | |
"id": None, # There is no tool id in the ToolOutput | |
"name": source.tool_name, | |
"input": source.raw_input, | |
}, | |
}, | |
} | |
def to_response(self): | |
match self.event_type: | |
case "retrieve": | |
return self.get_retrieval_message() | |
case "function_call": | |
return self.get_tool_message() | |
case "agent_step": | |
return self.get_agent_tool_response() | |
case _: | |
return None | |
class EventCallbackHandler(BaseCallbackHandler):
    """Callback handler that buffers reportable events in an asyncio queue.

    Each start/end event is wrapped in a ``CallbackEvent``; only events whose
    ``to_response()`` is not ``None`` are queued. Consumers drain the queue
    through :meth:`async_event_gen`.
    """

    _aqueue: asyncio.Queue
    # Read by `async_event_gen` to decide when to stop; never set inside this
    # class -- presumably the caller flips it to True once the run finishes.
    is_done: bool = False

    def __init__(
        self,
    ):
        """Initialize the handler and its event queue.

        The same list of event types is ignored for both event starts and
        event ends.
        """
        ignored_events = [
            CBEventType.CHUNKING,
            CBEventType.NODE_PARSING,
            CBEventType.EMBEDDING,
            CBEventType.LLM,
            CBEventType.TEMPLATING,
        ]
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def _enqueue_if_reportable(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]],
        event_id: str,
    ) -> None:
        """Queue the event when it converts to a non-empty response."""
        event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        if event.to_response() is not None:
            self._aqueue.put_nowait(event)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Record the start of an event.

        Returns:
            The ``event_id`` that was passed in, matching the declared
            ``str`` return type.
        """
        self._enqueue_if_reportable(event_type, payload, event_id)
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Record the end of an event."""
        self._enqueue_if_reportable(event_type, payload, event_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """No-op."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """No-op."""

    async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
        """Yield queued events until the handler is done and the queue is empty.

        Waits on the queue in 0.1-second slices so that ``is_done`` is
        re-checked regularly instead of blocking forever on an empty queue.
        """
        while not self._aqueue.empty() or not self.is_done:
            try:
                event = await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                # Nothing arrived in this slice; loop to re-check `is_done`.
                continue
            # Yield outside the try so only the queue wait is covered by the
            # timeout handler.
            yield event