Improve agent output formatting with inline citations and full sources
from pathlib import Path
from typing import Any
import opik
from loguru import logger
from opik import opik_context
from smolagents import LiteLLMModel, MessageRole, MultiStepAgent, ToolCallingAgent
from second_brain_online.config import settings
from .tools import (
HuggingFaceEndpointSummarizerTool,
MongoDBRetrieverTool,
OpenAISummarizerTool,
what_can_i_do,
)
def get_agent(retriever_config_path: Path) -> "AgentWrapper":
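    """Builds and returns an AgentWrapper configured from the given retriever config file."""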
agent = AgentWrapper.build_from_smolagents(
retriever_config_path=retriever_config_path
)
return agent
class AgentWrapper:
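    """Thin wrapper around a smolagents MultiStepAgent that adds Opik tracing to
    `run` and post-processes the result to surface tool observations with sources."""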
def __init__(self, agent: MultiStepAgent) -> None:
self.__agent = agent
@property
def input_messages(self) -> list[dict]:
return self.__agent.input_messages
@property
def agent_name(self) -> str:
return self.__agent.agent_name
@property
    def max_steps(self) -> int:
return self.__agent.max_steps
@classmethod
def build_from_smolagents(cls, retriever_config_path: Path) -> "AgentWrapper":
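        """Wires the retriever, summarizer and `what_can_i_do` tools into a
        ToolCallingAgent, picking the summarizer backend based on settings."""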
retriever_tool = MongoDBRetrieverTool(config_path=retriever_config_path)
if settings.USE_HUGGINGFACE_DEDICATED_ENDPOINT:
logger.warning(
f"Using Hugging Face dedicated endpoint as the summarizer with URL: {settings.HUGGINGFACE_DEDICATED_ENDPOINT}"
)
summarizer_tool = HuggingFaceEndpointSummarizerTool()
else:
logger.warning(
f"Using OpenAI as the summarizer with model: {settings.OPENAI_MODEL_ID}"
)
summarizer_tool = OpenAISummarizerTool(stream=False)
model = LiteLLMModel(
model_id=settings.OPENAI_MODEL_ID,
api_base="https://api.openai.com/v1",
api_key=settings.OPENAI_API_KEY,
)
agent = ToolCallingAgent(
tools=[what_can_i_do, retriever_tool, summarizer_tool],
model=model,
max_steps=4, # Allow more steps for complex queries
verbosity_level=2,
)
return cls(agent)
@opik.track(name="Agent.run")
def run(self, task: str, **kwargs) -> Any:
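        """Runs the task, then returns the observations of the `answer_with_sources`
        tool call when one is found in the step history, falling back to the
        agent's final output otherwise."""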
result = self.__agent.run(task, return_full_result=True, **kwargs)
# Debug: Print step structure to understand the data
logger.info(f"Result type: {type(result)}")
if hasattr(result, 'steps'):
logger.info(f"Number of steps: {len(result.steps)}")
for i, step in enumerate(result.steps):
logger.info(f"Step {i}: type={type(step)}, keys={step.keys() if isinstance(step, dict) else 'not a dict'}")
if isinstance(step, dict) and 'tool_calls' in step:
logger.info(f" Tool calls: {step['tool_calls']}")
if step['tool_calls']:
for tc in step['tool_calls']:
tc_type = type(tc)
if isinstance(tc, dict):
logger.info(f" Tool call dict: {tc}")
else:
logger.info(f" Tool call object: {tc}, type: {tc_type}")
if hasattr(tc, 'function'):
logger.info(f" Function: {tc.function}")
if hasattr(tc, 'name'):
logger.info(f" Name: {tc.name}")
# Extract the raw output from answer_with_sources (Step 2) instead of using final_answer
if hasattr(result, 'steps') and len(result.steps) >= 2:
# Find the step where answer_with_sources was called
for step_idx, step in enumerate(result.steps):
if isinstance(step, dict) and 'tool_calls' in step and step['tool_calls']:
for tool_call in step['tool_calls']:
# Handle both dict and object formats
tool_name = None
if isinstance(tool_call, dict):
tool_name = tool_call.get('function', {}).get('name')
elif hasattr(tool_call, 'function'):
if hasattr(tool_call.function, 'name'):
tool_name = tool_call.function.name
elif hasattr(tool_call, 'name'):
tool_name = tool_call.name
if tool_name == 'answer_with_sources':
# Found the answer_with_sources step - return its observations
if 'observations' in step and step['observations']:
logger.info(f"✅ Found answer_with_sources at step {step_idx}, returning its observations")
return step['observations']
# Fallback to regular result.output
logger.warning("⚠️ answer_with_sources output not found, falling back to result.output")
if hasattr(result, 'output'):
return result.output
return result
def extract_tool_responses(agent: ToolCallingAgent) -> str:
"""
Extracts and concatenates all tool response contents with numbered observation delimiters.
Args:
input_messages (List[Dict]): List of message dictionaries containing 'role' and 'content' keys
Returns:
str: Tool response contents separated by numbered observation delimiters
Example:
>>> messages = [
... {"role": MessageRole.TOOL_RESPONSE, "content": "First response"},
... {"role": MessageRole.USER, "content": "Question"},
... {"role": MessageRole.TOOL_RESPONSE, "content": "Second response"}
... ]
>>> extract_tool_responses(messages)
"-------- OBSERVATION 1 --------\nFirst response\n-------- OBSERVATION 2 --------\nSecond response"
"""
tool_responses = [
msg["content"]
for msg in agent.input_messages
if msg["role"] == MessageRole.TOOL_RESPONSE
]
return "\n".join(
f"-------- OBSERVATION {i + 1} --------\n{response}"
for i, response in enumerate(tool_responses)
)
class OpikAgentMonitorCallback:
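    """Step callback that records each agent step (memory, tool calls and
    observations) as an Opik trace."""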
def __init__(self) -> None:
self.output_state: dict = {}
def __call__(self, step_log) -> None:
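        """Captures the step's memory and tool calls as the trace input and its
        observations as the trace output."""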
input_state = {
"agent_memory": step_log.agent_memory,
"tool_calls": step_log.tool_calls,
}
self.output_state = {"observations": step_log.observations}
self.trace(input_state)
@opik.track(name="Callback.agent_step")
    def trace(self, input_state: dict) -> dict:
        # Opik records `input_state` as the traced input; the step's observations
        # are returned so they are captured as the traced output.
        return self.output_state
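

# Minimal usage sketch (assumptions: a retriever YAML config exists at the path
# below and OpenAI credentials are available via `settings`; the path and task
# string are illustrative, not part of this module).
if __name__ == "__main__":
    example_agent = get_agent(Path("configs/retriever.yaml"))
    # `run` returns either the answer_with_sources observations or the final output.
    example_answer = example_agent.run("What can I do with my second brain?")
    logger.info(example_answer)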