from typing import Annotated, TypedDict

from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph
from langchain_openai import ChatOpenAI
from tools import all_tools
import inspect
import os
import re

# 1. Setup once — fail fast at import time if the key is absent.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Missing OPENAI_API_KEY environment variable.")

chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=OPENAI_API_KEY,
    temperature=0,  # deterministic output for benchmark answers
)
# Same model, but with the tool schemas attached so it can emit tool calls.
chat_with_tools = chat.bind_tools(all_tools)


# 2. Define the agent state
class AgentState(TypedDict):
    # Conversation history; the add_messages reducer appends new messages
    # instead of overwriting the list on each node return.
    messages: Annotated[list[AnyMessage], add_messages]


def extract_gaia_answer(text: str) -> str:
    """Extract just the final answer in raw, lowercased form.

    Strips explanation and prefixes such as:
      - 'The answer is: ...'
      - 'Answer: ...'
      - or a bare short line of letters/digits/commas/hyphens.

    Fallbacks (in order): first short (<80 chars) non-empty line,
    then the whole text lowercased.

    Args:
        text: raw model output.

    Returns:
        The extracted answer, stripped and lowercased.
    """
    patterns = [
        r"The answer is:\s*(.+)",
        r"Answer:\s*(.+)",
        r"^([a-z0-9\s,\-]+)$",  # simple raw line (numbers, text)
    ]
    for pattern in patterns:
        match = re.search(pattern, text.strip(), re.IGNORECASE | re.MULTILINE)
        if match:
            return match.group(1).strip().lower()

    # Fallback: return first short line if it's probably the answer
    lines = [l.strip() for l in text.strip().splitlines() if l.strip()]
    if lines and len(lines[0]) < 80:
        return lines[0].strip().lower()

    # Final fallback: return full text, lowercase
    return text.strip().lower()


# 3. Assistant node
def assistant(state: AgentState):
    """LLM reasoning node for the graph.

    Builds the system prompt (including a generated tool catalogue), invokes
    the tool-bound chat model on the conversation so far, and returns the
    response as a partial state update that add_messages appends.

    Args:
        state: current agent state; only state["messages"] is read.

    Returns:
        dict with a single-element "messages" list containing the model reply.
    """
    # One line per tool: "name(signature):\n description" for the prompt.
    tool_descriptions = "\n".join([
        f"{tool.name}{inspect.signature(tool.func)}:\n {tool.description.strip()}"
        for tool in all_tools
    ])

    sys_msg = SystemMessage(
        content=(
            "You are a helpful AI assistant who solves GAIA benchmark questions using step-by-step reasoning.\n"
            "Before answering, always think out loud and plan your approach.\n"
            "Use tools when you lack information or need external data. Only use audio or transcription tools if the user clearly provides or references an audio file.\n"
            "Do not assume the existence of files or media unless they are explicitly mentioned. Do not call tools like transcription unless an actual file or media reference is present.\n"
            "After every tool call, always analyze the result and continue reasoning to arrive at a final answer.\n"
            # BUGFIX: this line previously lacked its trailing "\n", fusing it
            # with the following instruction in the rendered prompt.
            "If the question is unclear, incomplete, or missing context, respond with: **'The question is incomplete — please provide more information.'**\n"
            "Never treat tool outputs as final — interpret them and continue solving the task step-by-step.\n"
            "When the question specifies an answer format (e.g., a number, list, or code), respond **only with the final answer** in the required format. Do not add explanations, quotes, or set notation. Output exactly what is requested.\n"
            "Finish with a clear and concise answer, such as 'The answer is: right'.\n"
            "\nAvailable tools:\n"
            f"{tool_descriptions}"
        )
    )

    input_msgs = [sys_msg] + state["messages"]

    # Debug trace of the prompt being sent (truncated for readability).
    print("\n🧠 Assistant received messages:")
    for msg in input_msgs:
        print(f"šŸ”¹ {msg.__class__.__name__}: {getattr(msg, 'content', '')[:200]}")

    output = chat_with_tools.invoke(input_msgs)

    print("\nšŸ—£ļø Assistant response:")
    print("-" * 40)
    print(getattr(output, 'content', '')[:500])
    print("-" * 40)

    return {
        "messages": [output],
    }
# 4. Build the agent graph
def build_graph(max_steps: int = 5):
    """Compile the assistant<->tools graph and wrap it in a step-limited driver.

    Args:
        max_steps: upper bound on outer driver iterations. BUGFIX: this value
            was previously ignored — the inner function shadowed it with its
            own hard-coded default of 5. It now seeds the inner default, so
            build_graph(10) actually allows 10 steps (callers may still
            override per call).

    Returns:
        limited_invoke(state, max_steps=..., max_reasoning_steps_after_tool=...)
        -> final state dict.
    """
    builder = StateGraph(AgentState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(all_tools))
    builder.add_edge(START, "assistant")
    # tools_condition routes to "tools" when the last AI message has tool
    # calls, otherwise to the graph's end.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    graph = builder.compile()

    def limited_invoke(state, max_steps: int = max_steps, max_reasoning_steps_after_tool: int = 2):
        """Repeatedly invoke the graph until an answer settles or steps run out."""
        reasoning_steps_since_last_tool = 0
        # BUGFIX: a for-loop replaces the previous while-loop so that the
        # `continue` in the reverse_sentence branch still consumes a step;
        # before, it skipped `steps += 1` and could loop forever.
        for step in range(max_steps):
            print(f"\U0001f501 Step {step + 1}")
            state = graph.invoke(state)

            # Debug trace: echo every AI message accumulated so far.
            for msg in state["messages"]:
                if isinstance(msg, AIMessage):
                    print("\nšŸ¤– Assistant says:")
                    print("-" * 40)
                    print(msg.content.strip())
                    print("-" * 40)

            latest_message = state["messages"][-1] if state["messages"] else None
            if not isinstance(latest_message, AIMessage):
                continue

            if latest_message.tool_calls:
                print("šŸ”„ Tool call detected — continuing loop.")
                reasoning_steps_since_last_tool = 0  # reset counter
                continue

            reasoning_steps_since_last_tool += 1
            print(f"🧠 No tool call — reasoning step #{reasoning_steps_since_last_tool}")

            # šŸ› ļø Handle reverse_sentence manually: re-feed the tool's reversed
            # output as a fresh human message so the model answers it directly.
            if "reverse_sentence" in latest_message.content.lower():
                tool_outputs = [msg for msg in state["messages"] if msg.type == "tool"]
                if tool_outputs:
                    reversed_text = tool_outputs[-1].content.strip()
                    print(f"šŸ” Re-feeding reversed message:\n{reversed_text}")
                    state["messages"].append(HumanMessage(content=reversed_text))
                    continue  # loop again with new input (step is consumed)

            if reasoning_steps_since_last_tool >= max_reasoning_steps_after_tool:
                print("āœ… Final answer assumed after sufficient reasoning.")
                break

        return state

    return limited_invoke
# 5. BasicAgent class
# (Removed: a stale commented-out duplicate of this class that predated the
# extract_gaia_answer post-processing step.)
class BasicAgent:
    """Thin wrapper: question string in, normalized GAIA answer string out."""

    def __init__(self, max_steps: int = 5):
        # build_graph returns the step-limited invoke callable, not the raw graph.
        self.graph = build_graph(max_steps)

    def __call__(self, question: str) -> str:
        """Run the agent on one question and return the extracted final answer.

        Args:
            question: the user question to solve.

        Returns:
            The lowercased, stripped answer; "No response." when the agent
            produced no messages at all.
        """
        response = self.graph({"messages": [HumanMessage(content=question)]})
        if not response.get("messages"):
            return "No response."
        final_message = response["messages"][-1]
        # getattr default mirrors the old hasattr check: a message object
        # without .content degrades to the sentinel string.
        raw_content = getattr(final_message, "content", "No final message.")
        return extract_gaia_answer(raw_content)


if __name__ == "__main__":
    agent = BasicAgent()
    print(agent("What is the capital of France?"))