| """ |
| GAIA Agent with Essential Tools for 30%+ Accuracy |
| Built with LangGraph and Groq LLM |
| """ |
| import os |
| import re |
| import json |
| from typing import Annotated |
| from langchain_core.tools import tool |
| from langchain_core.messages import SystemMessage |
| from langchain_community.tools.tavily_search import TavilySearchResults |
| from langchain_community.document_loaders import WikipediaLoader |
| from langchain_groq import ChatGroq |
| from langgraph.graph import StateGraph, MessagesState, START, END |
| from langgraph.prebuilt import ToolNode, tools_condition |
| from langgraph.checkpoint.memory import MemorySaver |
|
|
| |
def get_llm():
    """Build the Groq chat model used by the agent.

    Returns a deterministic (temperature=0) `llama-3.1-8b-instant` client
    with a bounded token budget, timeout, and retry policy.
    """
    # Collect the configuration in one place, then unpack it into the client.
    settings = dict(
        model="llama-3.1-8b-instant",
        temperature=0,
        max_tokens=8000,
        timeout=60,
        max_retries=2,
    )
    return ChatGroq(**settings)
|
|
| |
| |
| |
|
|
@tool
def web_search(query: str) -> str:
    """
    Search the web for current information using Tavily.
    Use this for finding recent information, facts, statistics, or any data not in your training.

    Args:
        query: The search query string

    Returns:
        Search results as formatted text
    """
    # NOTE: the docstring above doubles as the tool description shown to the
    # LLM, so its wording is part of the agent's runtime behavior.
    try:
        searcher = TavilySearchResults(
            max_results=5,
            search_depth="advanced",
            include_answer=True,
            include_raw_content=False,
        )
        hits = searcher.invoke(query)

        if not hits:
            return "No results found."

        # Render each hit as a numbered, labeled section.
        sections = [
            f"Result {idx}:\n"
            f"Title: {hit.get('title', 'No title')}\n"
            f"Content: {hit.get('content', 'No content')}\n"
            f"URL: {hit.get('url', '')}\n"
            for idx, hit in enumerate(hits, 1)
        ]
        return "\n".join(sections)
    except Exception as e:
        # Surface failures as text so the agent can recover instead of crashing.
        return f"Error searching web: {str(e)}"
|
|
|
|
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for encyclopedic information.
    Use this for historical facts, biographies, scientific concepts, etc.

    Args:
        query: The Wikipedia search query

    Returns:
        Wikipedia article content
    """
    # NOTE: the docstring above is the tool description the LLM sees at runtime.
    try:
        # Cap at two articles, 4000 chars each, to keep the context small.
        docs = WikipediaLoader(
            query=query, load_max_docs=2, doc_content_chars_max=4000
        ).load()

        if not docs:
            return f"No Wikipedia article found for '{query}'"

        body = "\n\n---\n\n".join(doc.page_content for doc in docs)
        return f"Wikipedia results for '{query}':\n\n{body}"
    except Exception as e:
        # Report errors as text so the agent loop keeps running.
        return f"Error searching Wikipedia: {str(e)}"
|
|
|
|
@tool
def calculate(expression: str) -> str:
    """
    Evaluate a mathematical expression safely.
    Supports basic arithmetic: +, -, *, /, //, %, **, parentheses.
    Also supports common math functions: abs, round, min, max, sum.

    Args:
        expression: Mathematical expression as a string (e.g., "2 + 2", "sqrt(16)", "10 ** 2")

    Returns:
        The calculated result
    """
    try:
        import math

        # Whitelisted names exposed to the evaluated expression.
        safe_dict = {
            'abs': abs, 'round': round, 'min': min, 'max': max, 'sum': sum,
            'sqrt': math.sqrt, 'pow': pow, 'log': math.log, 'log10': math.log10,
            'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
            'pi': math.pi, 'e': math.e, 'ceil': math.ceil, 'floor': math.floor
        }

        expression = expression.strip()

        # SECURITY: an empty __builtins__ alone does not make eval safe —
        # dunder attribute access (e.g. "".__class__.__mro__...) can reach
        # arbitrary objects. No legitimate math expression needs "__",
        # so reject it outright before evaluating.
        if "__" in expression:
            return f"Error calculating '{expression}': double underscores are not allowed"

        # eval with no builtins and only the whitelisted math names in scope.
        result = eval(expression, {"__builtins__": {}}, safe_dict)
        return str(result)
    except Exception as e:
        return f"Error calculating '{expression}': {str(e)}"
|
|
|
|
@tool
def python_executor(code: str) -> str:
    """
    Execute Python code safely for data processing and calculations.
    Use this for complex calculations, data manipulation, or multi-step computations.
    The code should print its output.

    Args:
        code: Python code to execute

    Returns:
        The output of the code execution
    """
    try:
        import io
        import math
        import json
        from contextlib import redirect_stdout
        from datetime import datetime, timedelta

        # Restricted execution environment: a curated set of builtins plus a
        # few safe stdlib helpers. No __import__, no open, no exec/eval.
        safe_globals = {
            '__builtins__': {
                'print': print, 'len': len, 'range': range, 'str': str,
                'int': int, 'float': float, 'list': list, 'dict': dict,
                'set': set, 'tuple': tuple, 'sorted': sorted, 'sum': sum,
                'min': min, 'max': max, 'abs': abs, 'round': round,
                'enumerate': enumerate, 'zip': zip, 'map': map, 'filter': filter,
            },
            'math': math,
            'json': json,
            'datetime': datetime,
            'timedelta': timedelta,
        }

        buffer = io.StringIO()
        # BUGFIX: the previous manual sys.stdout swap could leave stdout
        # pointing at the buffer (or raise NameError in the handler) if an
        # error occurred before/around the swap. redirect_stdout guarantees
        # restoration on both success and failure.
        with redirect_stdout(buffer):
            exec(code, safe_globals)

        output = buffer.getvalue()
        return output if output else "Code executed successfully (no output)"
    except Exception as e:
        return f"Error executing code: {str(e)}"
|
|
|
|
@tool
def read_file(filepath: str) -> str:
    """
    Read and return the contents of a file.
    Supports text files, CSV, JSON, and basic file formats.

    Args:
        filepath: Path to the file to read

    Returns:
        File contents as string
    """
    try:
        if not os.path.exists(filepath):
            return f"File not found: {filepath}"

        # JSON: parse and pretty-print so the agent sees clean structure.
        if filepath.endswith('.json'):
            with open(filepath, 'r', encoding='utf-8') as handle:
                return json.dumps(json.load(handle), indent=2)

        # CSV: prefer a pandas rendering; fall back to raw text when pandas
        # is not installed (ImportError is the deliberate fallback trigger).
        if filepath.endswith('.csv'):
            try:
                import pandas as pd
                frame = pd.read_csv(filepath)
                return (
                    f"CSV file with {len(frame)} rows and "
                    f"{len(frame.columns)} columns:\n\n{frame.to_string()}"
                )
            except ImportError:
                with open(filepath, 'r', encoding='utf-8') as handle:
                    return handle.read()

        # Anything else: return verbatim UTF-8 text.
        with open(filepath, 'r', encoding='utf-8') as handle:
            return handle.read()
    except Exception as e:
        return f"Error reading file '{filepath}': {str(e)}"
|
|
|
|
| |
| |
| |
|
|
# System prompt injected once per conversation by the assistant node in
# build_graph(). Its exact wording is runtime behavior: it enforces the
# GAIA benchmark's strict answer-format rules and advertises the tools
# bound to the LLM. Do not edit casually — format drift costs accuracy.
GAIA_SYSTEM_PROMPT = """You are a helpful AI assistant designed to answer questions from the GAIA benchmark.

CRITICAL ANSWER FORMAT RULES:
1. For numbers: NO commas, NO units (unless explicitly requested)
   - CORRECT: "1000" or "1000 meters" (if units requested)
   - WRONG: "1,000" or "1000 meters" (if units not requested)

2. For text answers: No articles (a, an, the), no abbreviations
   - CORRECT: "United States"
   - WRONG: "The United States" or "USA"

3. For lists: Comma-separated with one space after each comma
   - CORRECT: "apple, banana, orange"
   - WRONG: "apple,banana,orange" or "apple, banana, orange."

4. For dates: Use the format specified in the question
   - If not specified, use ISO format: YYYY-MM-DD

5. Be precise and concise - answer ONLY what is asked

APPROACH:
1. Read the question carefully and identify what information is needed
2. Use tools to gather information (web search, Wikipedia, calculations)
3. For multi-step questions, break down the problem and solve step by step
4. Verify your answer matches the format requirements above
5. Return ONLY the final answer in the correct format

AVAILABLE TOOLS:
- web_search: Search the internet for current information
- wikipedia_search: Search Wikipedia for encyclopedic knowledge
- calculate: Perform mathematical calculations
- python_executor: Execute Python code for complex computations
- read_file: Read files (CSV, JSON, text)

Remember: Your final response should be ONLY the answer in the correct format, nothing else.
"""
|
|
| |
| |
| |
|
|
def build_graph():
    """Build the LangGraph agent with tools.

    Wires an assistant node (Groq LLM with tools bound) to a ToolNode in the
    standard ReAct loop: assistant -> tools (when tool calls are emitted) ->
    assistant, terminating when the LLM answers directly. Compiled with an
    in-memory checkpointer, so callers must pass a thread_id in the config.
    """
    toolbox = [
        web_search,
        wikipedia_search,
        calculate,
        python_executor,
        read_file,
    ]

    # LLM with the tool schemas attached so it can emit tool calls.
    model = get_llm().bind_tools(toolbox)

    def assistant(state: MessagesState):
        """Assistant node that calls the LLM."""
        history = state["messages"]

        # Inject the GAIA system prompt exactly once per thread.
        if not any(isinstance(msg, SystemMessage) for msg in history):
            history = [SystemMessage(content=GAIA_SYSTEM_PROMPT)] + history

        return {"messages": [model.invoke(history)]}

    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(toolbox))

    # assistant decides; tools_condition routes to "tools" or END.
    workflow.add_edge(START, "assistant")
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")

    # Per-thread conversation memory.
    return workflow.compile(checkpointer=MemorySaver())
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    """Test the agent with sample questions"""
    from langchain_core.messages import HumanMessage

    print("Building agent...")
    graph = build_graph()

    # Smoke tests covering calculation, knowledge recall, and web search.
    samples = [
        "What is 25 * 4 + 100?",
        "Who was the first president of the United States?",
        "Search for the population of Tokyo in 2024",
    ]

    banner = '=' * 60
    for idx, question in enumerate(samples, 1):
        print(f"\n{banner}")
        print(f"Test {idx}: {question}")
        print(banner)

        try:
            # Distinct thread_id per question keeps checkpointed memory separate.
            run_config = {"configurable": {"thread_id": f"test_{idx}"}}
            outcome = graph.invoke(
                {"messages": [HumanMessage(content=question)]},
                config=run_config,
            )
            print(f"Answer: {outcome['messages'][-1].content}")
        except Exception as e:
            print(f"Error: {e}")
|
|
|
|