"""
Simple smoke-test script for the GAIA agent: checks API keys, builds the
agent graph, and runs a few sample questions through it.
"""
import os
import sys
import traceback

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage

from agent import build_graph
# Horizontal rule used to frame each section of the console report.
RULE = "=" * 60

# Plain-ASCII status markers. The emoji markers previously used here were
# mis-encoded in this file, and emoji can raise UnicodeEncodeError when stdout
# is redirected to a legacy code page, so ASCII keeps the report portable.
OK = "[OK]"
FAIL = "[FAIL]"

# API keys whose presence is reported before building the agent. A missing key
# is reported but not treated as fatal here; building the agent decides that.
API_KEY_VARS = ("GROQ_API_KEY", "TAVILY_API_KEY")

# Text that precedes the answer in the agent's reply, when present.
FINAL_ANSWER_MARKER = "Final Answer:"

# Each case: the question to ask, the coarse shape the answer must have
# ("number" or "text", see matches_expected_type), and a label for the report.
test_questions = [
    {
        "question": "What is 25 * 4?",
        "expected_type": "number",
        "description": "Simple calculation test",
    },
    {
        "question": "Who was the first president of the United States? Answer with just the name.",
        "expected_type": "text",
        "description": "Simple knowledge test",
    },
]


def print_section(title: str) -> None:
    """Print *title* framed by horizontal rules, preceded by a blank line."""
    print(f"\n{RULE}")
    print(title)
    print(RULE)


def report_api_key(env_var: str) -> bool:
    """Print whether *env_var* is set, without revealing any part of its value.

    Returns True when the variable is set to a non-empty string.
    """
    value = os.getenv(env_var)
    if not value:
        print(f"{FAIL} {env_var} not found in environment")
        return False
    # Never echo key material: even a prefix of a secret ends up in terminal
    # scrollback and CI logs. The length is enough to spot a truncated paste.
    print(f"{OK} {env_var} found ({len(value)} characters)")
    return True


def extract_final_answer(content) -> str:
    """Return the text after the last FINAL_ANSWER_MARKER in a message's content.

    *content* may be a plain string or a list of content blocks (strings, or
    dicts carrying a ``"text"`` key); the text parts are concatenated first.
    When the marker is absent the whole text is returned. Surrounding
    whitespace is stripped in both cases.
    """
    if isinstance(content, str):
        text = content
    else:
        parts = []
        for block in content or ():
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict):
                parts.append(str(block.get("text", "")))
        text = "".join(parts)
    # rpartition yields ("", "", text) when the marker is missing, so [2] is
    # either the tail after the last marker or the untouched text.
    return text.rpartition(FINAL_ANSWER_MARKER)[2].strip()


def matches_expected_type(answer: str, expected_type: str) -> bool:
    """Coarse shape check for *answer*.

    ``"number"`` answers must parse as a float (thousands separators allowed);
    any other type only requires a non-empty answer.
    """
    if not answer:
        return False
    if expected_type == "number":
        try:
            float(answer.replace(",", ""))
        except ValueError:
            return False
    return True


def run_case(agent, index: int, case: dict) -> bool:
    """Ask *agent* one question from ``test_questions`` and report the outcome.

    Each case is invoked with its own ``thread_id`` (presumably so any
    checkpointed conversation state stays separate per question). Returns True
    when the agent answered and the answer has the case's expected shape.
    """
    print_section(f"Test {index}: {case['description']}\nQuestion: {case['question']}")
    config = {"configurable": {"thread_id": f"test_{index}"}}
    try:
        result = agent.invoke(
            {"messages": [HumanMessage(content=case["question"])]},
            config=config,
        )
        answer = extract_final_answer(result["messages"][-1].content)
    except Exception as exc:  # report any agent failure, then continue with the next case
        print(f"{FAIL} Error: {exc}")
        traceback.print_exc()
        return False

    expected_type = case["expected_type"]
    if not matches_expected_type(answer, expected_type):
        print(f"{FAIL} Answer: {answer!r} is not a valid {expected_type} answer")
        return False
    print(f"{OK} Answer: {answer}")
    return True


def main() -> int:
    """Check API keys, build the agent and run every case.

    Returns the process exit code: 0 when every case passes, 1 when the agent
    cannot be built or any case fails, so callers (e.g. CI) can rely on it.
    """
    load_dotenv()

    print("Checking API keys...")
    for env_var in API_KEY_VARS:
        report_api_key(env_var)

    print_section("Building agent...")
    try:
        agent = build_graph()
    except Exception as exc:  # top-level boundary: nothing to test without an agent
        print(f"{FAIL} Error building agent: {exc}")
        traceback.print_exc()
        return 1
    print(f"{OK} Agent built successfully!")

    print_section("Running tests...")
    passed = 0
    for index, case in enumerate(test_questions, 1):
        if run_case(agent, index, case):
            passed += 1

    total = len(test_questions)
    print_section(f"Tests completed! {passed}/{total} passed")
    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())