Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Any | |
| import opik | |
| from loguru import logger | |
| from opik import opik_context | |
| from smolagents import LiteLLMModel, MessageRole, MultiStepAgent, ToolCallingAgent | |
| from second_brain_online.config import settings | |
| from .tools import ( | |
| HuggingFaceEndpointSummarizerTool, | |
| MongoDBRetrieverTool, | |
| OpenAISummarizerTool, | |
| what_can_i_do, | |
| ) | |
def get_agent(retriever_config_path: Path) -> "AgentWrapper":
    """Build and return an :class:`AgentWrapper` configured from the given retriever config file.

    Args:
        retriever_config_path: Path to the retriever tool's configuration file.

    Returns:
        A fully constructed ``AgentWrapper`` ready to run tasks.
    """
    return AgentWrapper.build_from_smolagents(
        retriever_config_path=retriever_config_path
    )
class AgentWrapper:
    """Thin wrapper around a smolagents ``MultiStepAgent``.

    Builds the agent with the project's retriever/summarizer tools and
    post-processes run results so callers receive the raw
    ``answer_with_sources`` observations instead of the agent's final answer.
    """

    def __init__(self, agent: MultiStepAgent) -> None:
        self.__agent = agent

    def input_messages(self) -> list[dict]:
        """Return the wrapped agent's accumulated input messages."""
        return self.__agent.input_messages

    def agent_name(self) -> str:
        """Return the wrapped agent's name."""
        return self.__agent.agent_name

    def max_steps(self) -> int:
        """Return the wrapped agent's maximum step budget.

        NOTE(fix): the return annotation was ``str``; ``max_steps`` is an
        integer step count (it is passed as ``max_steps=4`` at build time).
        """
        return self.__agent.max_steps

    @classmethod
    def build_from_smolagents(cls, retriever_config_path: Path) -> "AgentWrapper":
        """Construct an AgentWrapper wired with retriever and summarizer tools.

        Bug fix: added the missing ``@classmethod`` decorator. The method is
        invoked on the class (see ``get_agent``) and ends with
        ``return cls(agent)``, so without the decorator the call would fail.

        Args:
            retriever_config_path: Path to the MongoDB retriever tool config.

        Returns:
            A new ``AgentWrapper`` around a configured ``ToolCallingAgent``.
        """
        retriever_tool = MongoDBRetrieverTool(config_path=retriever_config_path)

        # Choose the summarizer backend based on deployment settings.
        if settings.USE_HUGGINGFACE_DEDICATED_ENDPOINT:
            logger.warning(
                f"Using Hugging Face dedicated endpoint as the summarizer with URL: {settings.HUGGINGFACE_DEDICATED_ENDPOINT}"
            )
            summarizer_tool = HuggingFaceEndpointSummarizerTool()
        else:
            logger.warning(
                f"Using OpenAI as the summarizer with model: {settings.OPENAI_MODEL_ID}"
            )
            summarizer_tool = OpenAISummarizerTool(stream=False)

        model = LiteLLMModel(
            model_id=settings.OPENAI_MODEL_ID,
            api_base="https://api.openai.com/v1",
            api_key=settings.OPENAI_API_KEY,
        )
        agent = ToolCallingAgent(
            tools=[what_can_i_do, retriever_tool, summarizer_tool],
            model=model,
            max_steps=4,  # Allow more steps for complex queries
            verbosity_level=2,
        )
        return cls(agent)

    def run(self, task: str, **kwargs) -> Any:
        """Run the agent on ``task`` and return the best available answer.

        Preference order:
          1. the raw ``observations`` of the step that called the
             ``answer_with_sources`` tool (keeps the sources intact);
          2. ``result.output`` as a fallback;
          3. the raw result object.

        Args:
            task: Natural-language task/question for the agent.
            **kwargs: Passed through to ``MultiStepAgent.run``.
        """
        result = self.__agent.run(task, return_full_result=True, **kwargs)

        self._log_step_debug(result)

        answer = self._find_answer_with_sources(result)
        if answer is not None:
            return answer

        # Fallback to regular result.output
        logger.warning("⚠️ answer_with_sources output not found, falling back to result.output")
        if hasattr(result, 'output'):
            return result.output
        return result

    @staticmethod
    def _log_step_debug(result) -> None:
        """Debug helper: log the run result's step structure to understand the data."""
        logger.info(f"Result type: {type(result)}")
        if not hasattr(result, 'steps'):
            return
        logger.info(f"Number of steps: {len(result.steps)}")
        for i, step in enumerate(result.steps):
            logger.info(f"Step {i}: type={type(step)}, keys={step.keys() if isinstance(step, dict) else 'not a dict'}")
            if isinstance(step, dict) and 'tool_calls' in step:
                logger.info(f" Tool calls: {step['tool_calls']}")
                if step['tool_calls']:
                    for tc in step['tool_calls']:
                        tc_type = type(tc)
                        if isinstance(tc, dict):
                            logger.info(f" Tool call dict: {tc}")
                        else:
                            logger.info(f" Tool call object: {tc}, type: {tc_type}")
                            if hasattr(tc, 'function'):
                                logger.info(f" Function: {tc.function}")
                            if hasattr(tc, 'name'):
                                logger.info(f" Name: {tc.name}")

    @staticmethod
    def _tool_call_name(tool_call) -> Any:
        """Best-effort extraction of a tool call's function name.

        Handles both dict-shaped and object-shaped tool calls; returns
        ``None`` when no name can be determined.
        """
        if isinstance(tool_call, dict):
            return tool_call.get('function', {}).get('name')
        if hasattr(tool_call, 'function'):
            if hasattr(tool_call.function, 'name'):
                return tool_call.function.name
            if hasattr(tool_call, 'name'):
                return tool_call.name
            return None
        if hasattr(tool_call, 'name'):
            return tool_call.name
        return None

    def _find_answer_with_sources(self, result) -> Any:
        """Return the observations of the step that called ``answer_with_sources``.

        Returns ``None`` when the result has fewer than two steps or no such
        step (with non-empty observations) exists.
        """
        if not (hasattr(result, 'steps') and len(result.steps) >= 2):
            return None
        for step_idx, step in enumerate(result.steps):
            if not (isinstance(step, dict) and step.get('tool_calls')):
                continue
            for tool_call in step['tool_calls']:
                if self._tool_call_name(tool_call) != 'answer_with_sources':
                    continue
                # Found the answer_with_sources step - return its observations
                if step.get('observations'):
                    logger.info(f"✅ Found answer_with_sources at step {step_idx}, returning its observations")
                    return step['observations']
        return None
def extract_tool_responses(agent: ToolCallingAgent) -> str:
    """
    Extract and concatenate all tool response contents with numbered observation delimiters.

    Fix: the previous docstring documented a nonexistent ``input_messages``
    parameter and showed the function being called with a raw list; the
    function actually takes the agent and reads ``agent.input_messages``.

    Args:
        agent (ToolCallingAgent): Agent whose ``input_messages`` (list of
            dicts with 'role' and 'content' keys) are scanned for tool
            responses.

    Returns:
        str: Tool response contents separated by numbered observation delimiters.

    Example:
        >>> agent.input_messages = [
        ...     {"role": MessageRole.TOOL_RESPONSE, "content": "First response"},
        ...     {"role": MessageRole.USER, "content": "Question"},
        ...     {"role": MessageRole.TOOL_RESPONSE, "content": "Second response"},
        ... ]
        >>> extract_tool_responses(agent)
        "-------- OBSERVATION 1 --------\\nFirst response\\n-------- OBSERVATION 2 --------\\nSecond response"
    """
    tool_responses = [
        msg["content"]
        for msg in agent.input_messages
        if msg["role"] == MessageRole.TOOL_RESPONSE
    ]
    return "\n".join(
        f"-------- OBSERVATION {i + 1} --------\n{response}"
        for i, response in enumerate(tool_responses)
    )
class OpikAgentMonitorCallback:
    """Step callback that snapshots agent observations for monitoring.

    An instance is called once per agent step; it stores the latest
    observations in ``output_state`` and forwards an input snapshot to
    :meth:`trace`.
    """

    def __init__(self) -> None:
        # Latest observations captured from a step log; empty until first call.
        self.output_state: dict = {}

    def __call__(self, step_log) -> None:
        """Record the step's observations and pass the input snapshot to trace()."""
        snapshot = {
            "agent_memory": step_log.agent_memory,
            "tool_calls": step_log.tool_calls,
        }
        self.output_state = {"observations": step_log.observations}
        self.trace(snapshot)

    def trace(self, step_log) -> dict:
        """Return the last captured output state.

        NOTE(review): the ``step_log`` argument is currently unused —
        possibly a stripped decorator or unfinished integration; preserved
        as-is.
        """
        return self.output_state